├── .gitignore
├── AdaMerging.png
├── LICENSE
├── README.md
├── checkpoints
│   └── README.md
├── data
│   └── README.md
├── logs
│   └── ViT-B-32
│       ├── log_20230920_214801_Layer_wise_AdaMergingPP.txt
│       ├── log_20230921_193533_Layer_wise_AdaMerging.txt
│       ├── log_20230921_202228_Task_wise_AdaMerging.txt
│       └── log_20230921_205130_Task_wise_AdaMergingPP.txt
└── src
    ├── args.py
    ├── datasets
    │   ├── cars.py
    │   ├── common.py
    │   ├── dtd.py
    │   ├── eurosat.py
    │   ├── gtsrb.py
    │   ├── mnist.py
    │   ├── registry.py
    │   ├── resisc45.py
    │   ├── sun397.py
    │   ├── svhn.py
    │   └── templates.py
    ├── eval.py
    ├── heads.py
    ├── main_layer_wise_adamerging.py
    ├── main_layer_wise_adamergingpp.py
    ├── main_task_arithmetic.py
    ├── main_task_wise_adamerging.py
    ├── main_task_wise_adamergingpp.py
    ├── main_ties_merging.py
    ├── merging_cofficient.py
    ├── modeling.py
    ├── task_vectors.py
    ├── ties_merging_utils.py
    └── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .DS_Store
3 |
--------------------------------------------------------------------------------
/AdaMerging.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnnengYang/AdaMerging/cea1165ad900113b8a52ae5ff190626576635f9b/AdaMerging.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Enneng Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaMerging
2 |
3 | The repository for **'[AdaMerging: Adaptive Model Merging for Multi-Task Learning](https://arxiv.org/abs/2310.02575). ICLR, 2024.'**.
4 |
5 |
6 | ## Abstract
7 | > Multi-task learning (MTL) aims to empower a model to tackle multiple tasks simultaneously.
A recent development known as task arithmetic has revealed that several models, each fine-tuned for distinct tasks, can be directly merged into a single model to execute MTL without necessitating a retraining process using the initial training data. Nevertheless, this direct addition of models often leads to a significant deterioration in the overall performance of the merged model. This decline occurs due to potential conflicts and intricate correlations among the multiple tasks. Consequently, the challenge emerges of how to merge pre-trained models more effectively without using their original training data. This paper introduces an innovative technique called Adaptive Model Merging (AdaMerging). This approach aims to autonomously learn the coefficients for model merging, either in a task-wise or layer-wise manner, without relying on the original training data. Specifically, our AdaMerging method operates as an automatic, unsupervised task arithmetic scheme. It leverages entropy minimization on unlabeled test samples from the multi-task setup as a surrogate objective function to iteratively refine the merging coefficients of the multiple models. Our experimental findings across eight tasks demonstrate the efficacy of the AdaMerging scheme we put forth. Compared to the current state-of-the-art (SOTA) task arithmetic merging scheme, AdaMerging showcases a remarkable 11% improvement in performance. Notably, AdaMerging also exhibits superior generalization capabilities when applied to unseen downstream tasks. Furthermore, it displays a significantly enhanced robustness to data distribution shifts that may occur during the testing phase.
8 |
9 |
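For readers new to task arithmetic, here is a minimal sketch of the parameter-space operation the abstract refers to. It is illustrative only (the function names and the plain state-dict handling below are our own simplification, not the code in `src/task_vectors.py`): a task vector is the difference between fine-tuned and pre-trained weights, and merging adds a weighted sum of task vectors back onto the pre-trained weights.

```python
# Illustrative sketch of task arithmetic with merging coefficients.
# The function names here are hypothetical, not this repository's API.
import torch

def build_task_vector(pretrained_sd, finetuned_sd):
    """Task vector: tau_k = theta_k (fine-tuned) - theta_0 (pre-trained)."""
    return {name: finetuned_sd[name] - pretrained_sd[name] for name in pretrained_sd}

def merge(pretrained_sd, task_vectors, coefficients):
    """Merged weights: theta = theta_0 + sum_k lambda_k * tau_k."""
    merged = {name: w.clone() for name, w in pretrained_sd.items()}
    for lam, tau in zip(coefficients, task_vectors):
        for name in merged:
            merged[name] += lam * tau[name]
    return merged
```

Plain task arithmetic fixes all `coefficients` to a single shared value; AdaMerging instead treats them as learnable, one per task (task-wise) or one per task and layer (layer-wise). The learned values are exactly what the training logs at the end of this page print after each evaluation round.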
10 | ![AdaMerging](./AdaMerging.png)
11 |
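The coefficients themselves are learned by minimizing the entropy of the merged model's predictions on unlabeled test batches, in the spirit of Tent-style test-time adaptation (see the Acknowledgement section). Below is a self-contained toy sketch of that objective; the tiny linear "model", the batch size, learning rate, and step count are stand-ins chosen for illustration, not the repository's actual training loop (see `src/main_task_wise_adamerging.py` for the real one).

```python
# Toy sketch of AdaMerging's unsupervised objective: entropy minimization
# over the merging coefficients. All shapes and hyperparameters are made up.
import torch
import torch.nn.functional as F

def entropy_loss(logits):
    # Average Shannon entropy of the softmax predictions over the batch.
    probs = F.softmax(logits, dim=-1)
    return -(probs * torch.log(probs + 1e-8)).sum(dim=-1).mean()

torch.manual_seed(0)
num_tasks, dim, num_classes = 8, 16, 10

# Frozen toy stand-ins for the pre-trained weights and the task vectors.
pretrained = torch.randn(dim, num_classes)
task_vectors = [torch.randn(dim, num_classes) for _ in range(num_tasks)]

# Learnable task-wise coefficients (the logs below initialize them near 0.3).
alphas = torch.full((num_tasks,), 0.3, requires_grad=True)
optimizer = torch.optim.Adam([alphas], lr=1e-3)

def forward_merged(x):
    # Merged weights: theta = theta_0 + sum_k alpha_k * tau_k.
    weights = pretrained + sum(a * tau for a, tau in zip(alphas, task_vectors))
    return x @ weights

for _ in range(100):  # unlabeled "test" batches; labels are never used
    x = torch.randn(32, dim)
    loss = entropy_loss(forward_merged(x))
    optimizer.zero_grad()
    loss.backward()  # gradients flow only into the merging coefficients
    optimizer.step()
```

Entropy is a usable unsupervised surrogate because confident (low-entropy) predictions tend to be correct ones, and only a handful of coefficients are optimized, so the procedure is far cheaper than fine-tuning the merged model itself.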
12 |
13 |
14 | ## Citation
15 | If you find our paper or this resource helpful, please consider citing:
16 | ```
17 | @article{AdaMerging_ICLR_2024,
18 | title={AdaMerging: Adaptive Model Merging for Multi-Task Learning},
19 | author={Yang, Enneng and Wang, Zhenyi and Shen, Li and Liu, Shiwei and Guo, Guibing and Wang, Xingwei and Tao, Dacheng},
20 | journal={The Twelfth International Conference on Learning Representations},
21 | year={2024}
22 | }
23 | ```
24 | Thanks!
25 |
26 |
27 | ## Datasets
28 | Refer to the dataset processing instructions in the [task_vectors](https://github.com/mlfoundations/task_vectors) repository.
29 |
30 | Or you can download the processed data from [Baidu Cloud disk](https://pan.baidu.com/s/1w0Z2UVv3NVmqDhjH8WTOJQ?pwd=kvg6).
31 |
32 |
33 | ## Checkpoints
34 |
35 | You can download the fine-tuned checkpoints from [task_vectors#checkpoints](https://github.com/mlfoundations/task_vectors#checkpoints).
36 | The Google Drive folder is: [task_vectors_checkpoints](https://drive.google.com/drive/folders/1u_Tva6x0p6oxu5Eo0ZZsf-520Cc_3MKw)
37 |
38 | *Note: If ```torch.load(xxx_checkpoint).state_dict()``` fails, you can try ```pickle.load(open(xxx_checkpoint, 'rb')).state_dict()``` instead.*
39 |
40 | ## Code
41 | > [!NOTE]
42 | > We noticed that our [AdaMerging](https://github.com/tanganke/fusion_bench/tree/main/fusion_bench/method/adamerging) method has been implemented in [FusionBench](https://github.com/tanganke/fusion_bench), a model merging benchmark. Thanks to its excellent memory management, AdaMerging is highly efficient within that framework, with the entire training process taking **only a few minutes**.
43 |
44 | ### Train
45 |
46 | **If you want to train AdaMerging, run this part of the code. If you want to load the trained merging coefficients directly, refer to the Eval section.**
47 |
48 | First, enter the root directory of the source code.
49 | > cd root_path/src/
50 |
51 | Run Task Arithmetic [paper](https://arxiv.org/abs/2212.04089)
52 | > python main_task_arithmetic.py
53 |
54 | Run TIES-MERGING [paper](https://arxiv.org/abs/2306.01708)
55 | > python main_ties_merging.py
56 |
57 | Run Task-wise AdaMerging (Ours)
58 | > python main_task_wise_adamerging.py
59 |
60 | Run Task-wise AdaMerging++ (Ours)
61 | > python main_task_wise_adamergingpp.py
62 |
63 | Run Layer-wise AdaMerging (Ours)
64 | > python main_layer_wise_adamerging.py
65 |
66 | Run Layer-wise AdaMerging++ (Ours)
67 | > python main_layer_wise_adamergingpp.py
68 |
69 | *Note: Due to machine memory limitations, our implementation reloads the dataset at each step, which adds a significant amount of extra time. If your machine has enough memory, you can load all the data once before optimizing the merging coefficients, which speeds up training significantly (i.e., the merging coefficients can be trained in a matter of minutes).*
70 |
71 | ### Eval
72 | Alternatively, you can load our trained merging coefficients, which can be found in the *[merging_cofficient.py](https://github.com/EnnengYang/AdaMerging/blob/main/src/merging_cofficient.py)* file.
The general process is as follows (an excerpt from inside a class, hence the `self` references; `self.paramslist` holds one parameter tuple per model, the pre-trained backbone first and then one task vector per task, which is why the leading coefficient in the logs below stays fixed at 1.0):
73 |
74 | ```
75 | # load the trained merging coefficients
76 | import torch
77 | from merging_cofficient import get_merging_cofficients
78 | ralpha = get_merging_cofficients(method_name, model_name)
79 | self.alpha = torch.Tensor(ralpha)
80 |
81 | # wrap: mix each parameter across models with the learned coefficients
82 | if self.alpha.size()[0] == 1:  # task-wise merging: one coefficient vector shared by all layers
83 |     params = tuple(sum(tuple(pi * alphai for pi, alphai in zip(p, self.alpha[0].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
84 | else:  # layer-wise merging: a separate coefficient vector for each layer j
85 |     params = tuple(sum(tuple(pi * alphai for pi, alphai in zip(p, self.alpha[j].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
86 | ```
87 | More details can be found in the following code: https://github.com/EnnengYang/RepresentationSurgery
88 |
89 |
90 |
91 | ## Acknowledgement
92 | Our implementation builds on the code below; many thanks to the authors.
93 |
94 | - Task Arithmetic: https://github.com/mlfoundations/task_vectors
95 |
96 | - TIES-MERGING: https://github.com/prateeky2806/ties-merging/tree/main
97 |
98 | - Model Soups: https://github.com/mlfoundations/model-soups
99 |
100 | - Tent: https://github.com/DequanWang/tent
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | Place checkpoints
2 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Place datasets
2 |
--------------------------------------------------------------------------------
/logs/ViT-B-32/log_20230921_202228_Task_wise_AdaMerging.txt:
--------------------------------------------------------------------------------
1 | [tensor([1.0000, 0.2901, 0.2902, 0.2907, 0.3009, 0.2926, 0.2923, 0.2902, 0.2904])]
2 | Eval: Epoch: 9 dataset: SUN397 ACC: 0.5622768642832001
3 | Eval: Epoch: 9 dataset: Cars ACC: 0.5585126228081084
4 | Eval: Epoch: 9 dataset: RESISC45 ACC: 0.6725396825396825
5 | Eval: Epoch: 9 dataset: EuroSAT ACC: 0.8048148148148148
6 | Eval: Epoch: 9 dataset: SVHN ACC: 0.7983635525507068
7 | Eval: Epoch: 9 dataset: GTSRB ACC: 0.6969912905779889
8 | Eval: Epoch: 9 dataset: MNIST ACC: 0.9713
9 | Eval: Epoch: 9 dataset: DTD ACC: 0.5079787234042553
10 | Eval: Epoch: 9 Avg ACC:0.6965971938723446
11 |
12 | [tensor([1.0000, 0.2805, 0.2808, 0.2830, 0.3047, 0.2870, 0.2844, 0.2812, 0.2810])]
13 | Eval: Epoch: 19 dataset: SUN397 ACC: 0.5705737416402675
14 | Eval: Epoch: 19 dataset: Cars ACC: 0.5647307548812337
15 | Eval: Epoch: 19 dataset: RESISC45 ACC: 0.6753968253968254
16 | Eval: Epoch: 19 dataset: EuroSAT ACC: 0.8218518518518518
17 | Eval: Epoch: 19 dataset: SVHN ACC: 0.795021511985249
18 | Eval: Epoch: 19 dataset: GTSRB ACC: 0.695249406175772
19 | Eval: Epoch: 19 dataset: MNIST ACC: 0.9687
20 | Eval: Epoch: 19 dataset: DTD ACC: 0.5122340425531915
21 | Eval: Epoch: 19 Avg ACC:0.7004697668105488
22 |
23 | [tensor([1.0000, 0.2716, 0.2722, 0.2760, 0.3090, 0.2827, 0.2787, 0.2729, 0.2721])]
24 | Eval: Epoch: 29 dataset: SUN397 ACC: 0.5773620958415046
25 | Eval: Epoch: 29 dataset: Cars ACC: 0.5698296231811963
26 | Eval: Epoch: 29 dataset: RESISC45 ACC: 0.6780952380952381
27 | Eval: Epoch: 29 dataset: EuroSAT ACC: 0.8344444444444444
28 | Eval: Epoch: 29 dataset: SVHN ACC: 0.7926782421634911
29 | Eval: Epoch: 29 dataset: GTSRB ACC: 0.6931908155186065
30 | Eval: Epoch: 29 dataset: MNIST ACC: 0.9668
31 | Eval: Epoch: 29 dataset: DTD ACC: 0.5132978723404256
32 | Eval: Epoch: 29 Avg
ACC:0.7032122914481133 33 | 34 | [tensor([1.0000, 0.2634, 0.2642, 0.2689, 0.3120, 0.2806, 0.2740, 0.2655, 0.2637])] 35 | Eval: Epoch: 39 dataset: SUN397 ACC: 0.5817870970986071 36 | Eval: Epoch: 39 dataset: Cars ACC: 0.5730630518592215 37 | Eval: Epoch: 39 dataset: RESISC45 ACC: 0.6796825396825397 38 | Eval: Epoch: 39 dataset: EuroSAT ACC: 0.8466666666666667 39 | Eval: Epoch: 39 dataset: SVHN ACC: 0.7926782421634911 40 | Eval: Epoch: 39 dataset: GTSRB ACC: 0.6938242280285035 41 | Eval: Epoch: 39 dataset: MNIST ACC: 0.9648 42 | Eval: Epoch: 39 dataset: DTD ACC: 0.5138297872340426 43 | Eval: Epoch: 39 Avg ACC:0.705791451591634 44 | 45 | [tensor([1.0000, 0.2565, 0.2565, 0.2627, 0.3134, 0.2777, 0.2702, 0.2586, 0.2565])] 46 | Eval: Epoch: 49 dataset: SUN397 ACC: 0.5861618142505154 47 | Eval: Epoch: 49 dataset: Cars ACC: 0.575425942047009 48 | Eval: Epoch: 49 dataset: RESISC45 ACC: 0.6817460317460318 49 | Eval: Epoch: 49 dataset: EuroSAT ACC: 0.8555555555555555 50 | Eval: Epoch: 49 dataset: SVHN ACC: 0.7913721573448064 51 | Eval: Epoch: 49 dataset: GTSRB ACC: 0.6942992874109264 52 | Eval: Epoch: 49 dataset: MNIST ACC: 0.9618 53 | Eval: Epoch: 49 dataset: DTD ACC: 0.5117021276595745 54 | Eval: Epoch: 49 Avg ACC:0.7072578645018024 55 | 56 | [tensor([1.0000, 0.2499, 0.2495, 0.2588, 0.3129, 0.2748, 0.2689, 0.2522, 0.2500])] 57 | Eval: Epoch: 59 dataset: SUN397 ACC: 0.5901342585608689 58 | Eval: Epoch: 59 dataset: Cars ACC: 0.5767939311030966 59 | Eval: Epoch: 59 dataset: RESISC45 ACC: 0.6846031746031747 60 | Eval: Epoch: 59 dataset: EuroSAT ACC: 0.8596296296296296 61 | Eval: Epoch: 59 dataset: SVHN ACC: 0.7905654578979717 62 | Eval: Epoch: 59 dataset: GTSRB ACC: 0.6953285827395091 63 | Eval: Epoch: 59 dataset: MNIST ACC: 0.9596 64 | Eval: Epoch: 59 dataset: DTD ACC: 0.5122340425531915 65 | Eval: Epoch: 59 Avg ACC:0.7086111346359303 66 | 67 | [tensor([1.0000, 0.2444, 0.2435, 0.2557, 0.3117, 0.2748, 0.2691, 0.2476, 0.2439])] 68 | Eval: Epoch: 69 dataset: SUN397 ACC: 0.5918942022426711 69 | Eval: Epoch: 69 dataset: Cars ACC: 0.5786593707250342 70 | Eval: Epoch: 69 dataset: RESISC45 ACC: 0.6857142857142857 71 | Eval: Epoch: 69 dataset: EuroSAT ACC: 0.8611111111111112 72 | Eval: Epoch: 69 dataset: SVHN ACC: 0.7920251997541488 73 | Eval: Epoch: 69 dataset: GTSRB ACC: 0.6988915281076801 74 | Eval: Epoch: 69 dataset: MNIST ACC: 0.9584 75 | Eval: Epoch: 69 dataset: DTD ACC: 0.5090425531914894 76 | Eval: Epoch: 69 Avg ACC:0.7094672813558025 77 | 78 | [tensor([1.0000, 0.2391, 0.2377, 0.2537, 0.3122, 0.2764, 0.2725, 0.2450, 0.2385])] 79 | Eval: Epoch: 79 dataset: SUN397 ACC: 0.5919947704530598 80 | Eval: Epoch: 79 dataset: Cars ACC: 0.5775401069518716 81 | Eval: Epoch: 79 dataset: RESISC45 ACC: 0.6850793650793651 82 | Eval: Epoch: 79 dataset: EuroSAT ACC: 0.8625925925925926 83 | Eval: Epoch: 79 dataset: SVHN ACC: 0.7949446834665027 84 | Eval: Epoch: 79 dataset: GTSRB ACC: 0.705938242280285 85 | Eval: Epoch: 79 dataset: MNIST ACC: 0.9577 86 | Eval: Epoch: 79 dataset: DTD ACC: 0.5074468085106383 87 | Eval: Epoch: 79 Avg ACC:0.7104045711667895 88 | 89 | [tensor([1.0000, 0.2350, 0.2321, 0.2503, 0.3123, 0.2803, 0.2765, 0.2433, 0.2336])] 90 | Eval: Epoch: 89 dataset: SUN397 ACC: 0.5916930658218937 91 | Eval: Epoch: 89 dataset: Cars ACC: 0.5760477552543216 92 | Eval: Epoch: 89 dataset: RESISC45 ACC: 0.6823809523809524 93 | Eval: Epoch: 89 dataset: EuroSAT ACC: 0.8618518518518519 94 | Eval: Epoch: 89 dataset: SVHN ACC: 0.7998617086662569 95 | Eval: Epoch: 89 dataset: GTSRB ACC: 0.7125890736342043 96 | Eval: Epoch: 89 
dataset: MNIST ACC: 0.9572 97 | Eval: Epoch: 89 dataset: DTD ACC: 0.5031914893617021 98 | Eval: Epoch: 89 Avg ACC:0.7106019871213979 99 | 100 | [tensor([1.0000, 0.2311, 0.2270, 0.2478, 0.3134, 0.2835, 0.2806, 0.2420, 0.2294])] 101 | Eval: Epoch: 99 dataset: SUN397 ACC: 0.5907376678232011 102 | Eval: Epoch: 99 dataset: Cars ACC: 0.5733117771421465 103 | Eval: Epoch: 99 dataset: RESISC45 ACC: 0.680952380952381 104 | Eval: Epoch: 99 dataset: EuroSAT ACC: 0.8607407407407407 105 | Eval: Epoch: 99 dataset: SVHN ACC: 0.8045482483097726 106 | Eval: Epoch: 99 dataset: GTSRB ACC: 0.7191607284243864 107 | Eval: Epoch: 99 dataset: MNIST ACC: 0.9567 108 | Eval: Epoch: 99 dataset: DTD ACC: 0.5015957446808511 109 | Eval: Epoch: 99 Avg ACC:0.7109684110091848 110 | 111 | [tensor([1.0000, 0.2279, 0.2223, 0.2460, 0.3147, 0.2856, 0.2859, 0.2405, 0.2260])] 112 | Eval: Epoch: 109 dataset: SUN397 ACC: 0.590285110876452 113 | Eval: Epoch: 109 dataset: Cars ACC: 0.5715707001616714 114 | Eval: Epoch: 109 dataset: RESISC45 ACC: 0.6785714285714286 115 | Eval: Epoch: 109 dataset: EuroSAT ACC: 0.8618518518518519 116 | Eval: Epoch: 109 dataset: SVHN ACC: 0.8073140749846343 117 | Eval: Epoch: 109 dataset: GTSRB ACC: 0.7269200316706255 118 | Eval: Epoch: 109 dataset: MNIST ACC: 0.9556 119 | Eval: Epoch: 109 dataset: DTD ACC: 0.5 120 | Eval: Epoch: 109 Avg ACC:0.7115141497645829 121 | 122 | [tensor([1.0000, 0.2259, 0.2190, 0.2443, 0.3152, 0.2878, 0.2920, 0.2400, 0.2241])] 123 | Eval: Epoch: 119 dataset: SUN397 ACC: 0.5889274400362046 124 | Eval: Epoch: 119 dataset: Cars ACC: 0.5694565352568088 125 | Eval: Epoch: 119 dataset: RESISC45 ACC: 0.6771428571428572 126 | Eval: Epoch: 119 dataset: EuroSAT ACC: 0.8607407407407407 127 | Eval: Epoch: 119 dataset: SVHN ACC: 0.8103103872157345 128 | Eval: Epoch: 119 dataset: GTSRB ACC: 0.7344418052256532 129 | Eval: Epoch: 119 dataset: MNIST ACC: 0.9552 130 | Eval: Epoch: 119 dataset: DTD ACC: 0.49627659574468086 131 | Eval: Epoch: 119 Avg ACC:0.711562045170335 132 | 133 | [tensor([1.0000, 0.2237, 0.2146, 0.2460, 0.3189, 0.2891, 0.2978, 0.2375, 0.2221])] 134 | Eval: Epoch: 129 dataset: SUN397 ACC: 0.5877206215115403 135 | Eval: Epoch: 129 dataset: Cars ACC: 0.5668449197860963 136 | Eval: Epoch: 129 dataset: RESISC45 ACC: 0.6757142857142857 137 | Eval: Epoch: 129 dataset: EuroSAT ACC: 0.8633333333333333 138 | Eval: Epoch: 129 dataset: SVHN ACC: 0.8115780577750461 139 | Eval: Epoch: 129 dataset: GTSRB ACC: 0.741409342834521 140 | Eval: Epoch: 129 dataset: MNIST ACC: 0.9541 141 | Eval: Epoch: 129 dataset: DTD ACC: 0.4946808510638298 142 | Eval: Epoch: 129 Avg ACC:0.7119226765023314 143 | 144 | [tensor([1.0000, 0.2223, 0.2108, 0.2463, 0.3256, 0.2921, 0.3023, 0.2342, 0.2199])] 145 | Eval: Epoch: 139 dataset: SUN397 ACC: 0.5865138029868758 146 | Eval: Epoch: 139 dataset: Cars ACC: 0.5651038428056212 147 | Eval: Epoch: 139 dataset: RESISC45 ACC: 0.6726984126984127 148 | Eval: Epoch: 139 dataset: EuroSAT ACC: 0.8659259259259259 149 | Eval: Epoch: 139 dataset: SVHN ACC: 0.8131530424093424 150 | Eval: Epoch: 139 dataset: GTSRB ACC: 0.7454473475851148 151 | Eval: Epoch: 139 dataset: MNIST ACC: 0.9517 152 | Eval: Epoch: 139 dataset: DTD ACC: 0.49148936170212765 153 | Eval: Epoch: 139 Avg ACC:0.7115039670141775 154 | 155 | [tensor([1.0000, 0.2210, 0.2071, 0.2469, 0.3293, 0.2918, 0.3060, 0.2291, 0.2173])] 156 | Eval: Epoch: 149 dataset: SUN397 ACC: 0.5866143711972646 157 | Eval: Epoch: 149 dataset: Cars ACC: 0.5639845790324587 158 | Eval: Epoch: 149 dataset: RESISC45 ACC: 0.6733333333333333 159 | 
Eval: Epoch: 149 dataset: EuroSAT ACC: 0.8685185185185185 160 | Eval: Epoch: 149 dataset: SVHN ACC: 0.8123079287031346 161 | Eval: Epoch: 149 dataset: GTSRB ACC: 0.7513064133016627 162 | Eval: Epoch: 149 dataset: MNIST ACC: 0.9493 163 | Eval: Epoch: 149 dataset: DTD ACC: 0.48829787234042554 164 | Eval: Epoch: 149 Avg ACC:0.7117078770533498 165 | 166 | [tensor([1.0000, 0.2190, 0.2033, 0.2497, 0.3315, 0.2897, 0.3085, 0.2263, 0.2147])] 167 | Eval: Epoch: 159 dataset: SUN397 ACC: 0.5871674963544024 168 | Eval: Epoch: 159 dataset: Cars ACC: 0.5637358537495336 169 | Eval: Epoch: 159 dataset: RESISC45 ACC: 0.6763492063492064 170 | Eval: Epoch: 159 dataset: EuroSAT ACC: 0.8729629629629629 171 | Eval: Epoch: 159 dataset: SVHN ACC: 0.8099262446220037 172 | Eval: Epoch: 159 dataset: GTSRB ACC: 0.7548693586698337 173 | Eval: Epoch: 159 dataset: MNIST ACC: 0.9475 174 | Eval: Epoch: 159 dataset: DTD ACC: 0.48829787234042554 175 | Eval: Epoch: 159 Avg ACC:0.7126011243810461 176 | 177 | [tensor([1.0000, 0.2168, 0.2003, 0.2509, 0.3316, 0.2899, 0.3136, 0.2241, 0.2125])] 178 | Eval: Epoch: 169 dataset: SUN397 ACC: 0.5857092573037663 179 | Eval: Epoch: 169 dataset: Cars ACC: 0.562119139410521 180 | Eval: Epoch: 169 dataset: RESISC45 ACC: 0.6768253968253968 181 | Eval: Epoch: 169 dataset: EuroSAT ACC: 0.8725925925925926 182 | Eval: Epoch: 169 dataset: SVHN ACC: 0.8104256299938537 183 | Eval: Epoch: 169 dataset: GTSRB ACC: 0.7622327790973872 184 | Eval: Epoch: 169 dataset: MNIST ACC: 0.9461 185 | Eval: Epoch: 169 dataset: DTD ACC: 0.48723404255319147 186 | Eval: Epoch: 169 Avg ACC:0.7129048547220886 187 | 188 | [tensor([1.0000, 0.2165, 0.1982, 0.2516, 0.3297, 0.2904, 0.3178, 0.2235, 0.2115])] 189 | Eval: Epoch: 179 dataset: SUN397 ACC: 0.5854075526726001 190 | Eval: Epoch: 179 dataset: Cars ACC: 0.5603780624300461 191 | Eval: Epoch: 179 dataset: RESISC45 ACC: 0.6765079365079365 192 | Eval: Epoch: 179 dataset: EuroSAT ACC: 0.8696296296296296 193 | Eval: Epoch: 179 dataset: SVHN ACC: 0.8113475722188076 194 | Eval: Epoch: 179 dataset: GTSRB ACC: 0.7672209026128266 195 | Eval: Epoch: 179 dataset: MNIST ACC: 0.9458 196 | Eval: Epoch: 179 dataset: DTD ACC: 0.48723404255319147 197 | Eval: Epoch: 179 Avg ACC:0.7129407123281297 198 | 199 | [tensor([1.0000, 0.2144, 0.1972, 0.2519, 0.3271, 0.2915, 0.3227, 0.2221, 0.2098])] 200 | Eval: Epoch: 189 dataset: SUN397 ACC: 0.5843515864635189 201 | Eval: Epoch: 189 dataset: Cars ACC: 0.5592587986568834 202 | Eval: Epoch: 189 dataset: RESISC45 ACC: 0.676031746031746 203 | Eval: Epoch: 189 dataset: EuroSAT ACC: 0.8662962962962963 204 | Eval: Epoch: 189 dataset: SVHN ACC: 0.81399815611555 205 | Eval: Epoch: 189 dataset: GTSRB ACC: 0.7746634996041172 206 | Eval: Epoch: 189 dataset: MNIST ACC: 0.945 207 | Eval: Epoch: 189 dataset: DTD ACC: 0.48723404255319147 208 | Eval: Epoch: 189 Avg ACC:0.7133542657151629 209 | 210 | [tensor([1.0000, 0.2125, 0.1959, 0.2532, 0.3272, 0.2929, 0.3287, 0.2200, 0.2088])] 211 | Eval: Epoch: 199 dataset: SUN397 ACC: 0.5827927792024941 212 | Eval: Epoch: 199 dataset: Cars ACC: 0.5572689963934834 213 | Eval: Epoch: 199 dataset: RESISC45 ACC: 0.6765079365079365 214 | Eval: Epoch: 199 dataset: EuroSAT ACC: 0.8659259259259259 215 | Eval: Epoch: 199 dataset: SVHN ACC: 0.8161493546404426 216 | Eval: Epoch: 199 dataset: GTSRB ACC: 0.780443388756928 217 | Eval: Epoch: 199 dataset: MNIST ACC: 0.9435 218 | Eval: Epoch: 199 dataset: DTD ACC: 0.4856382978723404 219 | Eval: Epoch: 199 Avg ACC:0.7135283349124439 220 | 221 | [tensor([1.0000, 0.2113, 0.1929, 0.2542, 
0.3276, 0.2950, 0.3328, 0.2176, 0.2084])] 222 | Eval: Epoch: 209 dataset: SUN397 ACC: 0.5821390858349675 223 | Eval: Epoch: 209 dataset: Cars ACC: 0.5565228205447084 224 | Eval: Epoch: 209 dataset: RESISC45 ACC: 0.676031746031746 225 | Eval: Epoch: 209 dataset: EuroSAT ACC: 0.8651851851851852 226 | Eval: Epoch: 209 dataset: SVHN ACC: 0.8174938537185003 227 | Eval: Epoch: 209 dataset: GTSRB ACC: 0.7844022169437846 228 | Eval: Epoch: 209 dataset: MNIST ACC: 0.942 229 | Eval: Epoch: 209 dataset: DTD ACC: 0.4851063829787234 230 | Eval: Epoch: 209 Avg ACC:0.7136101614047019 231 | 232 | [tensor([1.0000, 0.2099, 0.1892, 0.2544, 0.3261, 0.2960, 0.3377, 0.2155, 0.2074])] 233 | Eval: Epoch: 219 dataset: SUN397 ACC: 0.5813848242570523 234 | Eval: Epoch: 219 dataset: Cars ACC: 0.5551548314886208 235 | Eval: Epoch: 219 dataset: RESISC45 ACC: 0.6768253968253968 236 | Eval: Epoch: 219 dataset: EuroSAT ACC: 0.8622222222222222 237 | Eval: Epoch: 219 dataset: SVHN ACC: 0.8196834665027658 238 | Eval: Epoch: 219 dataset: GTSRB ACC: 0.789944576405384 239 | Eval: Epoch: 219 dataset: MNIST ACC: 0.941 240 | Eval: Epoch: 219 dataset: DTD ACC: 0.48351063829787233 241 | Eval: Epoch: 219 Avg ACC:0.7137157444999143 242 | 243 | [tensor([1.0000, 0.2098, 0.1874, 0.2561, 0.3247, 0.2958, 0.3429, 0.2145, 0.2068])] 244 | Eval: Epoch: 229 dataset: SUN397 ACC: 0.5807814149947201 245 | Eval: Epoch: 229 dataset: Cars ACC: 0.5539112050739958 246 | Eval: Epoch: 229 dataset: RESISC45 ACC: 0.677936507936508 247 | Eval: Epoch: 229 dataset: EuroSAT ACC: 0.8603703703703703 248 | Eval: Epoch: 229 dataset: SVHN ACC: 0.8201444376152428 249 | Eval: Epoch: 229 dataset: GTSRB ACC: 0.795882818685669 250 | Eval: Epoch: 229 dataset: MNIST ACC: 0.9403 251 | Eval: Epoch: 229 dataset: DTD ACC: 0.4803191489361702 252 | Eval: Epoch: 229 Avg ACC:0.7137057379515845 253 | 254 | [tensor([1.0000, 0.2103, 0.1853, 0.2589, 0.3245, 0.2945, 0.3494, 0.2131, 0.2062])] 255 | Eval: Epoch: 239 dataset: SUN397 ACC: 0.5794740282596671 256 | Eval: Epoch: 239 dataset: Cars ACC: 0.5519214028105957 257 | Eval: Epoch: 239 dataset: RESISC45 ACC: 0.6788888888888889 258 | Eval: Epoch: 239 dataset: EuroSAT ACC: 0.8596296296296296 259 | Eval: Epoch: 239 dataset: SVHN ACC: 0.819299323909035 260 | Eval: Epoch: 239 dataset: GTSRB ACC: 0.801187648456057 261 | Eval: Epoch: 239 dataset: MNIST ACC: 0.9383 262 | Eval: Epoch: 239 dataset: DTD ACC: 0.4776595744680851 263 | Eval: Epoch: 239 Avg ACC:0.7132950620527447 264 | 265 | [tensor([1.0000, 0.2109, 0.1836, 0.2605, 0.3229, 0.2918, 0.3555, 0.2109, 0.2054])] 266 | Eval: Epoch: 249 dataset: SUN397 ACC: 0.5795243123648615 267 | Eval: Epoch: 249 dataset: Cars ACC: 0.5511752269618206 268 | Eval: Epoch: 249 dataset: RESISC45 ACC: 0.6803174603174603 269 | Eval: Epoch: 249 dataset: EuroSAT ACC: 0.8588888888888889 270 | Eval: Epoch: 249 dataset: SVHN ACC: 0.8172249539028887 271 | Eval: Epoch: 249 dataset: GTSRB ACC: 0.8079176563737134 272 | Eval: Epoch: 249 dataset: MNIST ACC: 0.9366 273 | Eval: Epoch: 249 dataset: DTD ACC: 0.47606382978723405 274 | Eval: Epoch: 249 Avg ACC:0.7134640410746085 275 | 276 | [tensor([1.0000, 0.2100, 0.1804, 0.2611, 0.3224, 0.2915, 0.3598, 0.2104, 0.2035])] 277 | Eval: Epoch: 259 dataset: SUN397 ACC: 0.5788203348921406 278 | Eval: Epoch: 259 dataset: Cars ACC: 0.5504290511130456 279 | Eval: Epoch: 259 dataset: RESISC45 ACC: 0.6807936507936508 280 | Eval: Epoch: 259 dataset: EuroSAT ACC: 0.8577777777777778 281 | Eval: Epoch: 259 dataset: SVHN ACC: 0.817839582052858 282 | Eval: Epoch: 259 dataset: GTSRB ACC: 
0.8131433095803642 283 | Eval: Epoch: 259 dataset: MNIST ACC: 0.9363 284 | Eval: Epoch: 259 dataset: DTD ACC: 0.474468085106383 285 | Eval: Epoch: 259 Avg ACC:0.7136964739145275 286 | 287 | [tensor([1.0000, 0.2102, 0.1765, 0.2594, 0.3206, 0.2938, 0.3637, 0.2105, 0.2014])] 288 | Eval: Epoch: 269 dataset: SUN397 ACC: 0.5783677779453915 289 | Eval: Epoch: 269 dataset: Cars ACC: 0.5488123367740331 290 | Eval: Epoch: 269 dataset: RESISC45 ACC: 0.6795238095238095 291 | Eval: Epoch: 269 dataset: EuroSAT ACC: 0.8522222222222222 292 | Eval: Epoch: 269 dataset: SVHN ACC: 0.821220036877689 293 | Eval: Epoch: 269 dataset: GTSRB ACC: 0.8174980205859066 294 | Eval: Epoch: 269 dataset: MNIST ACC: 0.9363 295 | Eval: Epoch: 269 dataset: DTD ACC: 0.4728723404255319 296 | Eval: Epoch: 269 Avg ACC:0.713352068044323 297 | 298 | [tensor([1.0000, 0.2110, 0.1739, 0.2590, 0.3193, 0.2976, 0.3662, 0.2108, 0.2005])] 299 | Eval: Epoch: 279 dataset: SUN397 ACC: 0.5775632322622819 300 | Eval: Epoch: 279 dataset: Cars ACC: 0.546325083944783 301 | Eval: Epoch: 279 dataset: RESISC45 ACC: 0.6776190476190476 302 | Eval: Epoch: 279 dataset: EuroSAT ACC: 0.8466666666666667 303 | Eval: Epoch: 279 dataset: SVHN ACC: 0.8248309772587584 304 | Eval: Epoch: 279 dataset: GTSRB ACC: 0.8192399049881235 305 | Eval: Epoch: 279 dataset: MNIST ACC: 0.9366 306 | Eval: Epoch: 279 dataset: DTD ACC: 0.47074468085106386 307 | Eval: Epoch: 279 Avg ACC:0.7124486991988407 308 | 309 | [tensor([1.0000, 0.2103, 0.1713, 0.2608, 0.3216, 0.2974, 0.3687, 0.2081, 0.1986])] 310 | Eval: Epoch: 289 dataset: SUN397 ACC: 0.5773118117363102 311 | Eval: Epoch: 289 dataset: Cars ACC: 0.545330182813083 312 | Eval: Epoch: 289 dataset: RESISC45 ACC: 0.6792063492063493 313 | Eval: Epoch: 289 dataset: EuroSAT ACC: 0.85 314 | Eval: Epoch: 289 dataset: SVHN ACC: 0.8242547633681623 315 | Eval: Epoch: 289 dataset: GTSRB ACC: 0.8209817893903405 316 | Eval: Epoch: 289 dataset: MNIST ACC: 0.9344 317 | Eval: Epoch: 289 dataset: DTD ACC: 0.46702127659574466 318 | Eval: Epoch: 289 Avg ACC:0.7123132716387488 319 | 320 | [tensor([1.0000, 0.2105, 0.1686, 0.2638, 0.3234, 0.2978, 0.3720, 0.2070, 0.1964])] 321 | Eval: Epoch: 299 dataset: SUN397 ACC: 0.5765072660532006 322 | Eval: Epoch: 299 dataset: Cars ACC: 0.5432160179082204 323 | Eval: Epoch: 299 dataset: RESISC45 ACC: 0.6795238095238095 324 | Eval: Epoch: 299 dataset: EuroSAT ACC: 0.8514814814814815 325 | Eval: Epoch: 299 dataset: SVHN ACC: 0.8242931776275353 326 | Eval: Epoch: 299 dataset: GTSRB ACC: 0.8230403800475059 327 | Eval: Epoch: 299 dataset: MNIST ACC: 0.9336 328 | Eval: Epoch: 299 dataset: DTD ACC: 0.4664893617021277 329 | Eval: Epoch: 299 Avg ACC:0.7122689367929852 330 | 331 | [tensor([1.0000, 0.2094, 0.1661, 0.2631, 0.3229, 0.2994, 0.3772, 0.2078, 0.1944])] 332 | Eval: Epoch: 309 dataset: SUN397 ACC: 0.5753507316337306 333 | Eval: Epoch: 309 dataset: Cars ACC: 0.5414749409277453 334 | Eval: Epoch: 309 dataset: RESISC45 ACC: 0.677936507936508 335 | Eval: Epoch: 309 dataset: EuroSAT ACC: 0.847037037037037 336 | Eval: Epoch: 309 dataset: SVHN ACC: 0.8262907191149355 337 | Eval: Epoch: 309 dataset: GTSRB ACC: 0.8274742676167854 338 | Eval: Epoch: 309 dataset: MNIST ACC: 0.9337 339 | Eval: Epoch: 309 dataset: DTD ACC: 0.46117021276595743 340 | Eval: Epoch: 309 Avg ACC:0.7113043021290875 341 | 342 | [tensor([1.0000, 0.2081, 0.1654, 0.2625, 0.3215, 0.3005, 0.3838, 0.2073, 0.1946])] 343 | Eval: Epoch: 319 dataset: SUN397 ACC: 0.5737416402675114 344 | Eval: Epoch: 319 dataset: Cars ACC: 0.5392364133814203 345 | Eval: Epoch: 
319 dataset: RESISC45 ACC: 0.6758730158730158 346 | Eval: Epoch: 319 dataset: EuroSAT ACC: 0.8425925925925926 347 | Eval: Epoch: 319 dataset: SVHN ACC: 0.8275583896742471 348 | Eval: Epoch: 319 dataset: GTSRB ACC: 0.8321456848772764 349 | Eval: Epoch: 319 dataset: MNIST ACC: 0.933 350 | Eval: Epoch: 319 dataset: DTD ACC: 0.4622340425531915 351 | Eval: Epoch: 319 Avg ACC:0.7107977224024069 352 | 353 | [tensor([1.0000, 0.2067, 0.1635, 0.2626, 0.3215, 0.3016, 0.3885, 0.2079, 0.1931])] 354 | Eval: Epoch: 329 dataset: SUN397 ACC: 0.5716799919545432 355 | Eval: Epoch: 329 dataset: Cars ACC: 0.5381171496082576 356 | Eval: Epoch: 329 dataset: RESISC45 ACC: 0.6749206349206349 357 | Eval: Epoch: 329 dataset: EuroSAT ACC: 0.84 358 | Eval: Epoch: 329 dataset: SVHN ACC: 0.8284035033804549 359 | Eval: Epoch: 329 dataset: GTSRB ACC: 0.835550277117973 360 | Eval: Epoch: 329 dataset: MNIST ACC: 0.9332 361 | Eval: Epoch: 329 dataset: DTD ACC: 0.4622340425531915 362 | Eval: Epoch: 329 Avg ACC:0.710513199941882 363 | 364 | [tensor([1.0000, 0.2057, 0.1604, 0.2641, 0.3229, 0.2967, 0.3910, 0.2061, 0.1897])] 365 | Eval: Epoch: 339 dataset: SUN397 ACC: 0.572383969427264 366 | Eval: Epoch: 339 dataset: Cars ACC: 0.5376196990424077 367 | Eval: Epoch: 339 dataset: RESISC45 ACC: 0.6765079365079365 368 | Eval: Epoch: 339 dataset: EuroSAT ACC: 0.8433333333333334 369 | Eval: Epoch: 339 dataset: SVHN ACC: 0.8253687768899816 370 | Eval: Epoch: 339 dataset: GTSRB ACC: 0.8382422802850357 371 | Eval: Epoch: 339 dataset: MNIST ACC: 0.9313 372 | Eval: Epoch: 339 dataset: DTD ACC: 0.45797872340425533 373 | Eval: Epoch: 339 Avg ACC:0.7103418398612767 374 | 375 | [tensor([1.0000, 0.2054, 0.1573, 0.2648, 0.3239, 0.2931, 0.3932, 0.2041, 0.1868])] 376 | Eval: Epoch: 349 dataset: SUN397 ACC: 0.573540503846734 377 | Eval: Epoch: 349 dataset: Cars ACC: 0.5381171496082576 378 | Eval: Epoch: 349 dataset: RESISC45 ACC: 0.6773015873015873 379 | Eval: Epoch: 349 dataset: EuroSAT ACC: 0.8455555555555555 380 | Eval: Epoch: 349 dataset: SVHN ACC: 0.822180393362016 381 | Eval: Epoch: 349 dataset: GTSRB ACC: 0.8398258115597783 382 | Eval: Epoch: 349 dataset: MNIST ACC: 0.9298 383 | Eval: Epoch: 349 dataset: DTD ACC: 0.45478723404255317 384 | Eval: Epoch: 349 Avg ACC:0.7101385294095602 385 | 386 | [tensor([1.0000, 0.2083, 0.1568, 0.2709, 0.3245, 0.2906, 0.3951, 0.2031, 0.1869])] 387 | Eval: Epoch: 359 dataset: SUN397 ACC: 0.5740433448986775 388 | Eval: Epoch: 359 dataset: Cars ACC: 0.5373709737594826 389 | Eval: Epoch: 359 dataset: RESISC45 ACC: 0.68 390 | Eval: Epoch: 359 dataset: EuroSAT ACC: 0.847037037037037 391 | Eval: Epoch: 359 dataset: SVHN ACC: 0.8187231100184389 392 | Eval: Epoch: 359 dataset: GTSRB ACC: 0.8408551068883611 393 | Eval: Epoch: 359 dataset: MNIST ACC: 0.9289 394 | Eval: Epoch: 359 dataset: DTD ACC: 0.4531914893617021 395 | Eval: Epoch: 359 Avg ACC:0.7100151327454624 396 | 397 | [tensor([1.0000, 0.2097, 0.1568, 0.2737, 0.3251, 0.2885, 0.3968, 0.2038, 0.1862])] 398 | Eval: Epoch: 369 dataset: SUN397 ACC: 0.5736410720571228 399 | Eval: Epoch: 369 dataset: Cars ACC: 0.5367491605521701 400 | Eval: Epoch: 369 dataset: RESISC45 ACC: 0.6804761904761905 401 | Eval: Epoch: 369 dataset: EuroSAT ACC: 0.8466666666666667 402 | Eval: Epoch: 369 dataset: SVHN ACC: 0.816379840196681 403 | Eval: Epoch: 369 dataset: GTSRB ACC: 0.841409342834521 404 | Eval: Epoch: 369 dataset: MNIST ACC: 0.9287 405 | Eval: Epoch: 369 dataset: DTD ACC: 0.4531914893617021 406 | Eval: Epoch: 369 Avg ACC:0.7096517202681317 407 | 408 | [tensor([1.0000, 0.2091, 
0.1553, 0.2723, 0.3230, 0.2890, 0.4026, 0.2043, 0.1841])] 409 | Eval: Epoch: 379 dataset: SUN397 ACC: 0.5726856740584302 410 | Eval: Epoch: 379 dataset: Cars ACC: 0.5352568088546201 411 | Eval: Epoch: 379 dataset: RESISC45 ACC: 0.6801587301587302 412 | Eval: Epoch: 379 dataset: EuroSAT ACC: 0.8433333333333334 413 | Eval: Epoch: 379 dataset: SVHN ACC: 0.8182237246465888 414 | Eval: Epoch: 379 dataset: GTSRB ACC: 0.8460807600950119 415 | Eval: Epoch: 379 dataset: MNIST ACC: 0.9286 416 | Eval: Epoch: 379 dataset: DTD ACC: 0.4526595744680851 417 | Eval: Epoch: 379 Avg ACC:0.70962482570185 418 | 419 | [tensor([1.0000, 0.2077, 0.1540, 0.2692, 0.3209, 0.2916, 0.4059, 0.2037, 0.1824])] 420 | Eval: Epoch: 389 dataset: SUN397 ACC: 0.5717805601649318 421 | Eval: Epoch: 389 dataset: Cars ACC: 0.5346349956473075 422 | Eval: Epoch: 389 dataset: RESISC45 ACC: 0.6784126984126985 423 | Eval: Epoch: 389 dataset: EuroSAT ACC: 0.8392592592592593 424 | Eval: Epoch: 389 dataset: SVHN ACC: 0.8219883220651506 425 | Eval: Epoch: 389 dataset: GTSRB ACC: 0.8486935866983373 426 | Eval: Epoch: 389 dataset: MNIST ACC: 0.9286 427 | Eval: Epoch: 389 dataset: DTD ACC: 0.45159574468085106 428 | Eval: Epoch: 389 Avg ACC:0.7093706458660671 429 | 430 | [tensor([1.0000, 0.2075, 0.1515, 0.2692, 0.3206, 0.2927, 0.4076, 0.2034, 0.1795])] 431 | Eval: Epoch: 399 dataset: SUN397 ACC: 0.5719816965857093 432 | Eval: Epoch: 399 dataset: Cars ACC: 0.533018281308295 433 | Eval: Epoch: 399 dataset: RESISC45 ACC: 0.6785714285714286 434 | Eval: Epoch: 399 dataset: EuroSAT ACC: 0.8381481481481482 435 | Eval: Epoch: 399 dataset: SVHN ACC: 0.8229102642901045 436 | Eval: Epoch: 399 dataset: GTSRB ACC: 0.8501187648456057 437 | Eval: Epoch: 399 dataset: MNIST ACC: 0.9285 438 | Eval: Epoch: 399 dataset: DTD ACC: 0.451063829787234 439 | Eval: Epoch: 399 Avg ACC:0.7092890516920656 440 | 441 | [tensor([1.0000, 0.2086, 0.1492, 0.2702, 0.3197, 0.2921, 0.4047, 0.2030, 0.1794])] 442 | Eval: Epoch: 409 dataset: SUN397 ACC: 0.5729873786895963 443 | Eval: Epoch: 409 dataset: Cars ACC: 0.5331426439497575 444 | Eval: Epoch: 409 dataset: RESISC45 ACC: 0.6798412698412698 445 | Eval: Epoch: 409 dataset: EuroSAT ACC: 0.8381481481481482 446 | Eval: Epoch: 409 dataset: SVHN ACC: 0.8227181929932391 447 | Eval: Epoch: 409 dataset: GTSRB ACC: 0.8481393507521774 448 | Eval: Epoch: 409 dataset: MNIST ACC: 0.9286 449 | Eval: Epoch: 409 dataset: DTD ACC: 0.451063829787234 450 | Eval: Epoch: 409 Avg ACC:0.7093301017701779 451 | 452 | [tensor([1.0000, 0.2098, 0.1449, 0.2720, 0.3202, 0.2917, 0.4005, 0.2022, 0.1792])] 453 | Eval: Epoch: 419 dataset: SUN397 ACC: 0.5760044250012571 454 | Eval: Epoch: 419 dataset: Cars ACC: 0.533018281308295 455 | Eval: Epoch: 419 dataset: RESISC45 ACC: 0.6819047619047619 456 | Eval: Epoch: 419 dataset: EuroSAT ACC: 0.8414814814814815 457 | Eval: Epoch: 419 dataset: SVHN ACC: 0.8224108789182545 458 | Eval: Epoch: 419 dataset: GTSRB ACC: 0.8456848772763262 459 | Eval: Epoch: 419 dataset: MNIST ACC: 0.9283 460 | Eval: Epoch: 419 dataset: DTD ACC: 0.4521276595744681 461 | Eval: Epoch: 419 Avg ACC:0.7101165456831056 462 | 463 | [tensor([1.0000, 0.2116, 0.1421, 0.2741, 0.3199, 0.2931, 0.4009, 0.2016, 0.1778])] 464 | Eval: Epoch: 429 dataset: SUN397 ACC: 0.5763061296324232 465 | Eval: Epoch: 429 dataset: Cars ACC: 0.5317746548936699 466 | Eval: Epoch: 429 dataset: RESISC45 ACC: 0.6823809523809524 467 | Eval: Epoch: 429 dataset: EuroSAT ACC: 0.8414814814814815 468 | Eval: Epoch: 429 dataset: SVHN ACC: 0.8236017209588199 469 | Eval: Epoch: 429 dataset: 
GTSRB ACC: 0.8456848772763262 470 | Eval: Epoch: 429 dataset: MNIST ACC: 0.9283 471 | Eval: Epoch: 429 dataset: DTD ACC: 0.4521276595744681 472 | Eval: Epoch: 429 Avg ACC:0.7102071845247677 473 | 474 | [tensor([1.0000, 0.2136, 0.1415, 0.2752, 0.3196, 0.2965, 0.4043, 0.2038, 0.1771])] 475 | Eval: Epoch: 439 dataset: SUN397 ACC: 0.5743953336350379 476 | Eval: Epoch: 439 dataset: Cars ACC: 0.5287899514985698 477 | Eval: Epoch: 439 dataset: RESISC45 ACC: 0.6815873015873016 478 | Eval: Epoch: 439 dataset: EuroSAT ACC: 0.8351851851851851 479 | Eval: Epoch: 439 dataset: SVHN ACC: 0.8256760909649662 480 | Eval: Epoch: 439 dataset: GTSRB ACC: 0.8468725257323833 481 | Eval: Epoch: 439 dataset: MNIST ACC: 0.9296 482 | Eval: Epoch: 439 dataset: DTD ACC: 0.45053191489361705 483 | Eval: Epoch: 439 Avg ACC:0.7090797879371326 484 | 485 | [tensor([1.0000, 0.2134, 0.1402, 0.2763, 0.3215, 0.2935, 0.4036, 0.2002, 0.1741])] 486 | Eval: Epoch: 449 dataset: SUN397 ACC: 0.5767084024739779 487 | Eval: Epoch: 449 dataset: Cars ACC: 0.5306553911205074 488 | Eval: Epoch: 449 dataset: RESISC45 ACC: 0.6826984126984127 489 | Eval: Epoch: 449 dataset: EuroSAT ACC: 0.8422222222222222 490 | Eval: Epoch: 449 dataset: SVHN ACC: 0.8232944068838353 491 | Eval: Epoch: 449 dataset: GTSRB ACC: 0.8471100554235946 492 | Eval: Epoch: 449 dataset: MNIST ACC: 0.9272 493 | Eval: Epoch: 449 dataset: DTD ACC: 0.44840425531914896 494 | Eval: Epoch: 449 Avg ACC:0.7097866432677125 495 | 496 | [tensor([1.0000, 0.2137, 0.1394, 0.2744, 0.3232, 0.2922, 0.4023, 0.1989, 0.1706])] 497 | Eval: Epoch: 459 dataset: SUN397 ACC: 0.5781666415246141 498 | Eval: Epoch: 459 dataset: Cars ACC: 0.531525929610745 499 | Eval: Epoch: 459 dataset: RESISC45 ACC: 0.6826984126984127 500 | Eval: Epoch: 459 dataset: EuroSAT ACC: 0.8455555555555555 501 | Eval: Epoch: 459 dataset: SVHN ACC: 0.822180393362016 502 | Eval: Epoch: 459 dataset: GTSRB ACC: 0.846714172604909 503 | Eval: Epoch: 459 dataset: MNIST ACC: 0.9261 504 | Eval: Epoch: 459 dataset: DTD ACC: 0.44893617021276594 505 | Eval: Epoch: 459 Avg ACC:0.7102346594461273 506 | 507 | [tensor([1.0000, 0.2146, 0.1399, 0.2744, 0.3260, 0.2913, 0.4018, 0.1991, 0.1694])] 508 | Eval: Epoch: 469 dataset: SUN397 ACC: 0.5782672097350028 509 | Eval: Epoch: 469 dataset: Cars ACC: 0.5318990175351325 510 | Eval: Epoch: 469 dataset: RESISC45 ACC: 0.6825396825396826 511 | Eval: Epoch: 469 dataset: EuroSAT ACC: 0.8496296296296296 512 | Eval: Epoch: 469 dataset: SVHN ACC: 0.8200291948371236 513 | Eval: Epoch: 469 dataset: GTSRB ACC: 0.846159936658749 514 | Eval: Epoch: 469 dataset: MNIST ACC: 0.9262 515 | Eval: Epoch: 469 dataset: DTD ACC: 0.4478723404255319 516 | Eval: Epoch: 469 Avg ACC:0.7103246264201064 517 | 518 | [tensor([1.0000, 0.2164, 0.1411, 0.2774, 0.3262, 0.2905, 0.4019, 0.1989, 0.1706])] 519 | Eval: Epoch: 479 dataset: SUN397 ACC: 0.5784683461557801 520 | Eval: Epoch: 479 dataset: Cars ACC: 0.5317746548936699 521 | Eval: Epoch: 479 dataset: RESISC45 ACC: 0.6828571428571428 522 | Eval: Epoch: 479 dataset: EuroSAT ACC: 0.8496296296296296 523 | Eval: Epoch: 479 dataset: SVHN ACC: 0.818799938537185 524 | Eval: Epoch: 479 dataset: GTSRB ACC: 0.8458432304038005 525 | Eval: Epoch: 479 dataset: MNIST ACC: 0.9259 526 | Eval: Epoch: 479 dataset: DTD ACC: 0.44840425531914896 527 | Eval: Epoch: 479 Avg ACC:0.7102096497245447 528 | 529 | [tensor([1.0000, 0.2194, 0.1419, 0.2805, 0.3284, 0.2863, 0.4000, 0.1980, 0.1706])] 530 | Eval: Epoch: 489 dataset: SUN397 ACC: 0.5795243123648615 531 | Eval: Epoch: 489 dataset: Cars ACC: 
0.5326451933839075 532 | Eval: Epoch: 489 dataset: RESISC45 ACC: 0.6861904761904762 533 | Eval: Epoch: 489 dataset: EuroSAT ACC: 0.8555555555555555 534 | Eval: Epoch: 489 dataset: SVHN ACC: 0.8139213275968039 535 | Eval: Epoch: 489 dataset: GTSRB ACC: 0.8444972288202692 536 | Eval: Epoch: 489 dataset: MNIST ACC: 0.925 537 | Eval: Epoch: 489 dataset: DTD ACC: 0.449468085106383 538 | Eval: Epoch: 489 Avg ACC:0.710850272377282 539 | 540 | [tensor([1.0000, 0.2202, 0.1413, 0.2826, 0.3284, 0.2841, 0.4003, 0.1978, 0.1692])] 541 | Eval: Epoch: 499 dataset: SUN397 ACC: 0.5802785739427767 542 | Eval: Epoch: 499 dataset: Cars ACC: 0.53227210545952 543 | Eval: Epoch: 499 dataset: RESISC45 ACC: 0.6880952380952381 544 | Eval: Epoch: 499 dataset: EuroSAT ACC: 0.857037037037037 545 | Eval: Epoch: 499 dataset: SVHN ACC: 0.8118853718500307 546 | Eval: Epoch: 499 dataset: GTSRB ACC: 0.8448931116389549 547 | Eval: Epoch: 499 dataset: MNIST ACC: 0.9244 548 | Eval: Epoch: 499 dataset: DTD ACC: 0.44840425531914896 549 | Eval: Epoch: 499 Avg ACC:0.7109082116678384 550 | -------------------------------------------------------------------------------- /logs/ViT-B-32/log_20230921_205130_Task_wise_AdaMergingPP.txt: -------------------------------------------------------------------------------- 1 | Eval: init: dataset: SUN397 ACC: 0.6506260371096696 2 | Eval: init: dataset: Cars ACC: 0.6443228454172366 3 | Eval: init: dataset: RESISC45 ACC: 0.7485714285714286 4 | Eval: init: dataset: EuroSAT ACC: 0.7733333333333333 5 | Eval: init: dataset: SVHN ACC: 0.8126536570374924 6 | Eval: init: dataset: GTSRB ACC: 0.6938242280285035 7 | Eval: init: dataset: MNIST ACC: 0.9652 8 | Eval: init: dataset: DTD ACC: 0.5452127659574468 9 | Eval: init: Avg ACC:0.7292180369318889 10 | 11 | [tensor([1.0000, 0.3099, 0.3098, 0.3097, 0.3097, 0.3097, 0.3098, 0.3098, 0.3100])] 12 | Eval: Epoch: 9 dataset: SUN397 ACC: 0.6471061497460653 13 | Eval: Epoch: 9 dataset: Cars ACC: 0.6422086805123741 14 | Eval: Epoch: 9 dataset: RESISC45 ACC: 0.7473015873015874 15 | Eval: Epoch: 9 dataset: EuroSAT ACC: 0.7762962962962963 16 | Eval: Epoch: 9 dataset: SVHN ACC: 0.8194529809465273 17 | Eval: Epoch: 9 dataset: GTSRB ACC: 0.6992082343626287 18 | Eval: Epoch: 9 dataset: MNIST ACC: 0.9673 19 | Eval: Epoch: 9 dataset: DTD ACC: 0.5425531914893617 20 | Eval: Epoch: 9 Avg ACC:0.7301783900818551 21 | 22 | [tensor([1.0000, 0.3193, 0.3193, 0.3192, 0.3195, 0.3188, 0.3194, 0.3187, 0.3195])] 23 | Eval: Epoch: 19 dataset: SUN397 ACC: 0.6437873988032383 24 | Eval: Epoch: 19 dataset: Cars ACC: 0.6387265265514239 25 | Eval: Epoch: 19 dataset: RESISC45 ACC: 0.7458730158730159 26 | Eval: Epoch: 19 dataset: EuroSAT ACC: 0.7785185185185185 27 | Eval: Epoch: 19 dataset: SVHN ACC: 0.8256376767055931 28 | Eval: Epoch: 19 dataset: GTSRB ACC: 0.7037212984956452 29 | Eval: Epoch: 19 dataset: MNIST ACC: 0.9697 30 | Eval: Epoch: 19 dataset: DTD ACC: 0.5452127659574468 31 | Eval: Epoch: 19 Avg ACC:0.7313971501131101 32 | 33 | [tensor([1.0000, 0.3276, 0.3273, 0.3280, 0.3287, 0.3275, 0.3284, 0.3267, 0.3285])] 34 | Eval: Epoch: 29 dataset: SUN397 ACC: 0.6409714889123548 35 | Eval: Epoch: 29 dataset: Cars ACC: 0.6376072627782614 36 | Eval: Epoch: 29 dataset: RESISC45 ACC: 0.746031746031746 37 | Eval: Epoch: 29 dataset: EuroSAT ACC: 0.7814814814814814 38 | Eval: Epoch: 29 dataset: SVHN ACC: 0.8315150583896742 39 | Eval: Epoch: 29 dataset: GTSRB ACC: 0.7098178939034046 40 | Eval: Epoch: 29 dataset: MNIST ACC: 0.9713 41 | Eval: Epoch: 29 dataset: DTD ACC: 0.5425531914893617 42 | Eval: Epoch: 
29 Avg ACC:0.7326597653732856 43 | 44 | [tensor([1.0000, 0.3348, 0.3332, 0.3360, 0.3376, 0.3350, 0.3368, 0.3331, 0.3357])] 45 | Eval: Epoch: 39 dataset: SUN397 ACC: 0.6381555790214714 46 | Eval: Epoch: 39 dataset: Cars ACC: 0.6336276582514613 47 | Eval: Epoch: 39 dataset: RESISC45 ACC: 0.7442857142857143 48 | Eval: Epoch: 39 dataset: EuroSAT ACC: 0.782962962962963 49 | Eval: Epoch: 39 dataset: SVHN ACC: 0.8355485556238476 50 | Eval: Epoch: 39 dataset: GTSRB ACC: 0.7125890736342043 51 | Eval: Epoch: 39 dataset: MNIST ACC: 0.9724 52 | Eval: Epoch: 39 dataset: DTD ACC: 0.5430851063829787 53 | Eval: Epoch: 39 Avg ACC:0.73283183127033 54 | 55 | [tensor([1.0000, 0.3404, 0.3377, 0.3434, 0.3449, 0.3424, 0.3452, 0.3396, 0.3416])] 56 | Eval: Epoch: 49 dataset: SUN397 ACC: 0.6349876803942274 57 | Eval: Epoch: 49 dataset: Cars ACC: 0.6310160427807486 58 | Eval: Epoch: 49 dataset: RESISC45 ACC: 0.7417460317460317 59 | Eval: Epoch: 49 dataset: EuroSAT ACC: 0.782962962962963 60 | Eval: Epoch: 49 dataset: SVHN ACC: 0.8390058389674248 61 | Eval: Epoch: 49 dataset: GTSRB ACC: 0.7167062549485352 62 | Eval: Epoch: 49 dataset: MNIST ACC: 0.9734 63 | Eval: Epoch: 49 dataset: DTD ACC: 0.5425531914893617 64 | Eval: Epoch: 49 Avg ACC:0.7327972504111615 65 | 66 | [tensor([1.0000, 0.3443, 0.3403, 0.3495, 0.3520, 0.3486, 0.3531, 0.3444, 0.3460])] 67 | Eval: Epoch: 59 dataset: SUN397 ACC: 0.6309146678734852 68 | Eval: Epoch: 59 dataset: Cars ACC: 0.6280313393856486 69 | Eval: Epoch: 59 dataset: RESISC45 ACC: 0.7415873015873016 70 | Eval: Epoch: 59 dataset: EuroSAT ACC: 0.7837037037037037 71 | Eval: Epoch: 59 dataset: SVHN ACC: 0.8426167793484942 72 | Eval: Epoch: 59 dataset: GTSRB ACC: 0.7209817893903404 73 | Eval: Epoch: 59 dataset: MNIST ACC: 0.974 74 | Eval: Epoch: 59 dataset: DTD ACC: 0.5414893617021277 75 | Eval: Epoch: 59 Avg ACC:0.7329156178738878 76 | 77 | [tensor([1.0000, 0.3467, 0.3406, 0.3542, 0.3594, 0.3536, 0.3589, 0.3465, 0.3478])] 78 | Eval: Epoch: 69 dataset: SUN397 ACC: 0.6289535877709056 79 | Eval: Epoch: 69 dataset: Cars ACC: 0.6250466359905484 80 | Eval: Epoch: 69 dataset: RESISC45 ACC: 0.7409523809523809 81 | Eval: Epoch: 69 dataset: EuroSAT ACC: 0.7855555555555556 82 | Eval: Epoch: 69 dataset: SVHN ACC: 0.8449216349108789 83 | Eval: Epoch: 69 dataset: GTSRB ACC: 0.7228028503562945 84 | Eval: Epoch: 69 dataset: MNIST ACC: 0.9746 85 | Eval: Epoch: 69 dataset: DTD ACC: 0.5420212765957447 86 | Eval: Epoch: 69 Avg ACC:0.7331067402665385 87 | 88 | [tensor([1.0000, 0.3476, 0.3384, 0.3582, 0.3673, 0.3566, 0.3633, 0.3462, 0.3474])] 89 | Eval: Epoch: 79 dataset: SUN397 ACC: 0.6278976215618243 90 | Eval: Epoch: 79 dataset: Cars ACC: 0.6228081084442234 91 | Eval: Epoch: 79 dataset: RESISC45 ACC: 0.7414285714285714 92 | Eval: Epoch: 79 dataset: EuroSAT ACC: 0.7881481481481482 93 | Eval: Epoch: 79 dataset: SVHN ACC: 0.8458051628764598 94 | Eval: Epoch: 79 dataset: GTSRB ACC: 0.7253365003958828 95 | Eval: Epoch: 79 dataset: MNIST ACC: 0.9746 96 | Eval: Epoch: 79 dataset: DTD ACC: 0.5388297872340425 97 | Eval: Epoch: 79 Avg ACC:0.733106737511144 98 | 99 | [tensor([1.0000, 0.3470, 0.3339, 0.3620, 0.3748, 0.3599, 0.3683, 0.3457, 0.3450])] 100 | Eval: Epoch: 89 dataset: SUN397 ACC: 0.6264896666163826 101 | Eval: Epoch: 89 dataset: Cars ACC: 0.6218132073125233 102 | Eval: Epoch: 89 dataset: RESISC45 ACC: 0.7412698412698413 103 | Eval: Epoch: 89 dataset: EuroSAT ACC: 0.7911111111111111 104 | Eval: Epoch: 89 dataset: SVHN ACC: 0.8471496619545175 105 | Eval: Epoch: 89 dataset: GTSRB ACC: 0.7281076801266825 106 | Eval: 
Epoch: 89 dataset: MNIST ACC: 0.9746 107 | Eval: Epoch: 89 dataset: DTD ACC: 0.5382978723404256 108 | Eval: Epoch: 89 Avg ACC:0.7336048800914354 109 | 110 | [tensor([1.0000, 0.3472, 0.3296, 0.3660, 0.3814, 0.3640, 0.3740, 0.3456, 0.3443])] 111 | Eval: Epoch: 99 dataset: SUN397 ACC: 0.6250817116709408 112 | Eval: Epoch: 99 dataset: Cars ACC: 0.6195746797661983 113 | Eval: Epoch: 99 dataset: RESISC45 ACC: 0.7412698412698413 114 | Eval: Epoch: 99 dataset: EuroSAT ACC: 0.7937037037037037 115 | Eval: Epoch: 99 dataset: SVHN ACC: 0.8491472034419176 116 | Eval: Epoch: 99 dataset: GTSRB ACC: 0.7309580364212194 117 | Eval: Epoch: 99 dataset: MNIST ACC: 0.9748 118 | Eval: Epoch: 99 dataset: DTD ACC: 0.5361702127659574 119 | Eval: Epoch: 99 Avg ACC:0.7338381736299723 120 | 121 | [tensor([1.0000, 0.3466, 0.3237, 0.3691, 0.3876, 0.3676, 0.3796, 0.3439, 0.3424])] 122 | Eval: Epoch: 109 dataset: SUN397 ACC: 0.6240760295670539 123 | Eval: Epoch: 109 dataset: Cars ACC: 0.6174605148613357 124 | Eval: Epoch: 109 dataset: RESISC45 ACC: 0.7409523809523809 125 | Eval: Epoch: 109 dataset: EuroSAT ACC: 0.7959259259259259 126 | Eval: Epoch: 109 dataset: SVHN ACC: 0.8504532882606023 127 | Eval: Epoch: 109 dataset: GTSRB ACC: 0.7347585114806018 128 | Eval: Epoch: 109 dataset: MNIST ACC: 0.9746 129 | Eval: Epoch: 109 dataset: DTD ACC: 0.5351063829787234 130 | Eval: Epoch: 109 Avg ACC:0.7341666292533279 131 | 132 | [tensor([1.0000, 0.3455, 0.3182, 0.3716, 0.3935, 0.3712, 0.3852, 0.3420, 0.3414])] 133 | Eval: Epoch: 119 dataset: SUN397 ACC: 0.6230200633579726 134 | Eval: Epoch: 119 dataset: Cars ACC: 0.614848899390623 135 | Eval: Epoch: 119 dataset: RESISC45 ACC: 0.7409523809523809 136 | Eval: Epoch: 119 dataset: EuroSAT ACC: 0.797037037037037 137 | Eval: Epoch: 119 dataset: SVHN ACC: 0.8521051014136447 138 | Eval: Epoch: 119 dataset: GTSRB ACC: 0.7380839271575613 139 | Eval: Epoch: 119 dataset: MNIST ACC: 0.9743 140 | Eval: Epoch: 119 dataset: DTD ACC: 0.5351063829787234 141 | Eval: Epoch: 119 Avg ACC:0.7344317240359929 142 | 143 | [tensor([1.0000, 0.3425, 0.3124, 0.3743, 0.3994, 0.3754, 0.3905, 0.3398, 0.3399])] 144 | Eval: Epoch: 129 dataset: SUN397 ACC: 0.622366369990446 145 | Eval: Epoch: 129 dataset: Cars ACC: 0.6131078224101479 146 | Eval: Epoch: 129 dataset: RESISC45 ACC: 0.74 147 | Eval: Epoch: 129 dataset: EuroSAT ACC: 0.8003703703703704 148 | Eval: Epoch: 129 dataset: SVHN ACC: 0.8538721573448064 149 | Eval: Epoch: 129 dataset: GTSRB ACC: 0.7412509897070467 150 | Eval: Epoch: 129 dataset: MNIST ACC: 0.974 151 | Eval: Epoch: 129 dataset: DTD ACC: 0.5319148936170213 152 | Eval: Epoch: 129 Avg ACC:0.7346103254299798 153 | 154 | [tensor([1.0000, 0.3403, 0.3060, 0.3772, 0.4051, 0.3792, 0.3964, 0.3375, 0.3378])] 155 | Eval: Epoch: 139 dataset: SUN397 ACC: 0.6214612560969478 156 | Eval: Epoch: 139 dataset: Cars ACC: 0.6109936575052854 157 | Eval: Epoch: 139 dataset: RESISC45 ACC: 0.7395238095238095 158 | Eval: Epoch: 139 dataset: EuroSAT ACC: 0.8022222222222222 159 | Eval: Epoch: 139 dataset: SVHN ACC: 0.8551014136447449 160 | Eval: Epoch: 139 dataset: GTSRB ACC: 0.7441805225653206 161 | Eval: Epoch: 139 dataset: MNIST ACC: 0.9738 162 | Eval: Epoch: 139 dataset: DTD ACC: 0.5281914893617021 163 | Eval: Epoch: 139 Avg ACC:0.734434296365004 164 | 165 | [tensor([1.0000, 0.3377, 0.2994, 0.3791, 0.4113, 0.3832, 0.4020, 0.3357, 0.3349])] 166 | Eval: Epoch: 149 dataset: SUN397 ACC: 0.6211092673605874 167 | Eval: Epoch: 149 dataset: Cars ACC: 0.6088794926004228 168 | Eval: Epoch: 149 dataset: RESISC45 ACC: 
0.7401587301587301 169 | Eval: Epoch: 149 dataset: EuroSAT ACC: 0.8037037037037037 170 | Eval: Epoch: 149 dataset: SVHN ACC: 0.8564074984634297 171 | Eval: Epoch: 149 dataset: GTSRB ACC: 0.7464766429136975 172 | Eval: Epoch: 149 dataset: MNIST ACC: 0.9736 173 | Eval: Epoch: 149 dataset: DTD ACC: 0.527127659574468 174 | Eval: Epoch: 149 Avg ACC:0.73468287434688 175 | 176 | [tensor([1.0000, 0.3358, 0.2940, 0.3806, 0.4176, 0.3850, 0.4085, 0.3323, 0.3329])] 177 | Eval: Epoch: 159 dataset: SUN397 ACC: 0.6205561422034495 178 | Eval: Epoch: 159 dataset: Cars ACC: 0.6071384156199477 179 | Eval: Epoch: 159 dataset: RESISC45 ACC: 0.7403174603174603 180 | Eval: Epoch: 159 dataset: EuroSAT ACC: 0.8062962962962963 181 | Eval: Epoch: 159 dataset: SVHN ACC: 0.8567148125384143 182 | Eval: Epoch: 159 dataset: GTSRB ACC: 0.7494061757719715 183 | Eval: Epoch: 159 dataset: MNIST ACC: 0.9729 184 | Eval: Epoch: 159 dataset: DTD ACC: 0.5260638297872341 185 | Eval: Epoch: 159 Avg ACC:0.7349241415668468 186 | 187 | [tensor([1.0000, 0.3339, 0.2893, 0.3821, 0.4232, 0.3858, 0.4146, 0.3282, 0.3306])] 188 | Eval: Epoch: 169 dataset: SUN397 ACC: 0.6201035852567004 189 | Eval: Epoch: 169 dataset: Cars ACC: 0.6055217012809352 190 | Eval: Epoch: 169 dataset: RESISC45 ACC: 0.7404761904761905 191 | Eval: Epoch: 169 dataset: EuroSAT ACC: 0.8107407407407408 192 | Eval: Epoch: 169 dataset: SVHN ACC: 0.8567916410571604 193 | Eval: Epoch: 169 dataset: GTSRB ACC: 0.7523357086302455 194 | Eval: Epoch: 169 dataset: MNIST ACC: 0.9727 195 | Eval: Epoch: 169 dataset: DTD ACC: 0.5260638297872341 196 | Eval: Epoch: 169 Avg ACC:0.7355916746536508 197 | 198 | [tensor([1.0000, 0.3329, 0.2844, 0.3846, 0.4307, 0.3880, 0.4196, 0.3265, 0.3289])] 199 | Eval: Epoch: 179 dataset: SUN397 ACC: 0.6191481872580078 200 | Eval: Epoch: 179 dataset: Cars ACC: 0.6034075363760726 201 | Eval: Epoch: 179 dataset: RESISC45 ACC: 0.7395238095238095 202 | Eval: Epoch: 179 dataset: EuroSAT ACC: 0.8144444444444444 203 | Eval: Epoch: 179 dataset: SVHN ACC: 0.8572141979102643 204 | Eval: Epoch: 179 dataset: GTSRB ACC: 0.7551860649247822 205 | Eval: Epoch: 179 dataset: MNIST ACC: 0.9722 206 | Eval: Epoch: 179 dataset: DTD ACC: 0.525 207 | Eval: Epoch: 179 Avg ACC:0.7357655300546726 208 | 209 | [tensor([1.0000, 0.3320, 0.2790, 0.3867, 0.4364, 0.3894, 0.4233, 0.3229, 0.3281])] 210 | Eval: Epoch: 189 dataset: SUN397 ACC: 0.6185950621008699 211 | Eval: Epoch: 189 dataset: Cars ACC: 0.6029100858102227 212 | Eval: Epoch: 189 dataset: RESISC45 ACC: 0.74 213 | Eval: Epoch: 189 dataset: EuroSAT ACC: 0.8166666666666667 214 | Eval: Epoch: 189 dataset: SVHN ACC: 0.8573678549477566 215 | Eval: Epoch: 189 dataset: GTSRB ACC: 0.7566112430720506 216 | Eval: Epoch: 189 dataset: MNIST ACC: 0.9717 217 | Eval: Epoch: 189 dataset: DTD ACC: 0.523936170212766 218 | Eval: Epoch: 189 Avg ACC:0.7359733853512915 219 | 220 | [tensor([1.0000, 0.3317, 0.2732, 0.3899, 0.4427, 0.3901, 0.4268, 0.3188, 0.3264])] 221 | Eval: Epoch: 199 dataset: SUN397 ACC: 0.6182430733645095 222 | Eval: Epoch: 199 dataset: Cars ACC: 0.60179082203706 223 | Eval: Epoch: 199 dataset: RESISC45 ACC: 0.7412698412698413 224 | Eval: Epoch: 199 dataset: EuroSAT ACC: 0.8218518518518518 225 | Eval: Epoch: 199 dataset: SVHN ACC: 0.8574446834665027 226 | Eval: Epoch: 199 dataset: GTSRB ACC: 0.7584323040380048 227 | Eval: Epoch: 199 dataset: MNIST ACC: 0.9713 228 | Eval: Epoch: 199 dataset: DTD ACC: 0.524468085106383 229 | Eval: Epoch: 199 Avg ACC:0.7368500826417692 230 | 231 | [tensor([1.0000, 0.3308, 0.2676, 0.3921, 0.4492, 
0.3902, 0.4309, 0.3136, 0.3253])] 232 | Eval: Epoch: 209 dataset: SUN397 ACC: 0.6182933574697038 233 | Eval: Epoch: 209 dataset: Cars ACC: 0.6006715582638975 234 | Eval: Epoch: 209 dataset: RESISC45 ACC: 0.7420634920634921 235 | Eval: Epoch: 209 dataset: EuroSAT ACC: 0.8248148148148148 236 | Eval: Epoch: 209 dataset: SVHN ACC: 0.8571373693915181 237 | Eval: Epoch: 209 dataset: GTSRB ACC: 0.7599366587490103 238 | Eval: Epoch: 209 dataset: MNIST ACC: 0.9707 239 | Eval: Epoch: 209 dataset: DTD ACC: 0.5234042553191489 240 | Eval: Epoch: 209 Avg ACC:0.7371276882589483 241 | 242 | [tensor([1.0000, 0.3294, 0.2619, 0.3941, 0.4549, 0.3916, 0.4361, 0.3092, 0.3243])] 243 | Eval: Epoch: 219 dataset: SUN397 ACC: 0.6176899482073717 244 | Eval: Epoch: 219 dataset: Cars ACC: 0.59880611864196 245 | Eval: Epoch: 219 dataset: RESISC45 ACC: 0.7423809523809524 246 | Eval: Epoch: 219 dataset: EuroSAT ACC: 0.827037037037037 247 | Eval: Epoch: 219 dataset: SVHN ACC: 0.8574830977258758 248 | Eval: Epoch: 219 dataset: GTSRB ACC: 0.7624703087885986 249 | Eval: Epoch: 219 dataset: MNIST ACC: 0.9698 250 | Eval: Epoch: 219 dataset: DTD ACC: 0.523936170212766 251 | Eval: Epoch: 219 Avg ACC:0.7374504541243202 252 | 253 | [tensor([1.0000, 0.3291, 0.2572, 0.3962, 0.4600, 0.3936, 0.4420, 0.3061, 0.3259])] 254 | Eval: Epoch: 229 dataset: SUN397 ACC: 0.6171368230502339 255 | Eval: Epoch: 229 dataset: Cars ACC: 0.5971894043029474 256 | Eval: Epoch: 229 dataset: RESISC45 ACC: 0.7412698412698413 257 | Eval: Epoch: 229 dataset: EuroSAT ACC: 0.8274074074074074 258 | Eval: Epoch: 229 dataset: SVHN ACC: 0.8582513829133375 259 | Eval: Epoch: 229 dataset: GTSRB ACC: 0.7649247822644497 260 | Eval: Epoch: 229 dataset: MNIST ACC: 0.969 261 | Eval: Epoch: 229 dataset: DTD ACC: 0.5228723404255319 262 | Eval: Epoch: 229 Avg ACC:0.7372564977042186 263 | 264 | [tensor([1.0000, 0.3283, 0.2522, 0.3975, 0.4634, 0.3958, 0.4489, 0.3045, 0.3253])] 265 | Eval: Epoch: 239 dataset: SUN397 ACC: 0.6169356866294564 266 | Eval: Epoch: 239 dataset: Cars ACC: 0.5950752393980848 267 | Eval: Epoch: 239 dataset: RESISC45 ACC: 0.7404761904761905 268 | Eval: Epoch: 239 dataset: EuroSAT ACC: 0.8285185185185185 269 | Eval: Epoch: 239 dataset: SVHN ACC: 0.859519053472649 270 | Eval: Epoch: 239 dataset: GTSRB ACC: 0.7683293745051465 271 | Eval: Epoch: 239 dataset: MNIST ACC: 0.9686 272 | Eval: Epoch: 239 dataset: DTD ACC: 0.5218085106382979 273 | Eval: Epoch: 239 Avg ACC:0.737407821704793 274 | 275 | [tensor([1.0000, 0.3274, 0.2476, 0.3985, 0.4663, 0.3978, 0.4552, 0.3022, 0.3211])] 276 | Eval: Epoch: 249 dataset: SUN397 ACC: 0.6163825614723186 277 | Eval: Epoch: 249 dataset: Cars ACC: 0.5945777888322348 278 | Eval: Epoch: 249 dataset: RESISC45 ACC: 0.7403174603174603 279 | Eval: Epoch: 249 dataset: EuroSAT ACC: 0.8288888888888889 280 | Eval: Epoch: 249 dataset: SVHN ACC: 0.8600952673632453 281 | Eval: Epoch: 249 dataset: GTSRB ACC: 0.7714964370546318 282 | Eval: Epoch: 249 dataset: MNIST ACC: 0.9682 283 | Eval: Epoch: 249 dataset: DTD ACC: 0.5202127659574468 284 | Eval: Epoch: 249 Avg ACC:0.7375213962357782 285 | 286 | [tensor([1.0000, 0.3266, 0.2447, 0.3996, 0.4698, 0.4008, 0.4610, 0.3012, 0.3177])] 287 | Eval: Epoch: 259 dataset: SUN397 ACC: 0.6157791522099865 288 | Eval: Epoch: 259 dataset: Cars ACC: 0.5938316129834598 289 | Eval: Epoch: 259 dataset: RESISC45 ACC: 0.7392063492063492 290 | Eval: Epoch: 259 dataset: EuroSAT ACC: 0.8288888888888889 291 | Eval: Epoch: 259 dataset: SVHN ACC: 0.8614013521819299 292 | Eval: Epoch: 259 dataset: GTSRB ACC: 
0.7749010292953286 293 | Eval: Epoch: 259 dataset: MNIST ACC: 0.9679 294 | Eval: Epoch: 259 dataset: DTD ACC: 0.5191489361702127 295 | Eval: Epoch: 259 Avg ACC:0.7376321651170195 296 | 297 | [tensor([1.0000, 0.3257, 0.2417, 0.4008, 0.4722, 0.4042, 0.4676, 0.3001, 0.3163])] 298 | Eval: Epoch: 269 dataset: SUN397 ACC: 0.6146729018957108 299 | Eval: Epoch: 269 dataset: Cars ACC: 0.5920905360029847 300 | Eval: Epoch: 269 dataset: RESISC45 ACC: 0.7384126984126984 301 | Eval: Epoch: 269 dataset: EuroSAT ACC: 0.8292592592592593 302 | Eval: Epoch: 269 dataset: SVHN ACC: 0.8626690227412416 303 | Eval: Epoch: 269 dataset: GTSRB ACC: 0.7787806809184481 304 | Eval: Epoch: 269 dataset: MNIST ACC: 0.9679 305 | Eval: Epoch: 269 dataset: DTD ACC: 0.5175531914893617 306 | Eval: Epoch: 269 Avg ACC:0.737667286339963 307 | 308 | [tensor([1.0000, 0.3249, 0.2389, 0.4023, 0.4755, 0.4077, 0.4746, 0.2987, 0.3152])] 309 | Eval: Epoch: 279 dataset: SUN397 ACC: 0.6138180721074068 310 | Eval: Epoch: 279 dataset: Cars ACC: 0.5894789205322721 311 | Eval: Epoch: 279 dataset: RESISC45 ACC: 0.7376190476190476 312 | Eval: Epoch: 279 dataset: EuroSAT ACC: 0.8288888888888889 313 | Eval: Epoch: 279 dataset: SVHN ACC: 0.8640519360786724 314 | Eval: Epoch: 279 dataset: GTSRB ACC: 0.7829770387965163 315 | Eval: Epoch: 279 dataset: MNIST ACC: 0.9678 316 | Eval: Epoch: 279 dataset: DTD ACC: 0.5164893617021277 317 | Eval: Epoch: 279 Avg ACC:0.7376404082156165 318 | 319 | [tensor([1.0000, 0.3232, 0.2341, 0.4022, 0.4776, 0.4107, 0.4808, 0.2955, 0.3127])] 320 | Eval: Epoch: 289 dataset: SUN397 ACC: 0.6135163674762407 321 | Eval: Epoch: 289 dataset: Cars ACC: 0.5883596567591096 322 | Eval: Epoch: 289 dataset: RESISC45 ACC: 0.7368253968253968 323 | Eval: Epoch: 289 dataset: EuroSAT ACC: 0.8274074074074074 324 | Eval: Epoch: 289 dataset: SVHN ACC: 0.865319606637984 325 | Eval: Epoch: 289 dataset: GTSRB ACC: 0.7855106888361045 326 | Eval: Epoch: 289 dataset: MNIST ACC: 0.9671 327 | Eval: Epoch: 289 dataset: DTD ACC: 0.5159574468085106 328 | Eval: Epoch: 289 Avg ACC:0.7374995713438441 329 | 330 | [tensor([1.0000, 0.3215, 0.2290, 0.4009, 0.4804, 0.4135, 0.4863, 0.2920, 0.3091])] 331 | Eval: Epoch: 299 dataset: SUN397 ACC: 0.6132146628450746 332 | Eval: Epoch: 299 dataset: Cars ACC: 0.587240392985947 333 | Eval: Epoch: 299 dataset: RESISC45 ACC: 0.7363492063492063 334 | Eval: Epoch: 299 dataset: EuroSAT ACC: 0.8281481481481482 335 | Eval: Epoch: 299 dataset: SVHN ACC: 0.8663952059004303 336 | Eval: Epoch: 299 dataset: GTSRB ACC: 0.7885193982581156 337 | Eval: Epoch: 299 dataset: MNIST ACC: 0.9668 338 | Eval: Epoch: 299 dataset: DTD ACC: 0.5154255319148936 339 | Eval: Epoch: 299 Avg ACC:0.737761568300227 340 | 341 | [tensor([1.0000, 0.3209, 0.2266, 0.3999, 0.4833, 0.4156, 0.4916, 0.2906, 0.3074])] 342 | Eval: Epoch: 309 dataset: SUN397 ACC: 0.6129632423191029 343 | Eval: Epoch: 309 dataset: Cars ACC: 0.586494217137172 344 | Eval: Epoch: 309 dataset: RESISC45 ACC: 0.734920634920635 345 | Eval: Epoch: 309 dataset: EuroSAT ACC: 0.8277777777777777 346 | Eval: Epoch: 309 dataset: SVHN ACC: 0.8671634910878918 347 | Eval: Epoch: 309 dataset: GTSRB ACC: 0.7908155186064925 348 | Eval: Epoch: 309 dataset: MNIST ACC: 0.9667 349 | Eval: Epoch: 309 dataset: DTD ACC: 0.5148936170212766 350 | Eval: Epoch: 309 Avg ACC:0.7377160623587936 351 | 352 | [tensor([1.0000, 0.3216, 0.2241, 0.4017, 0.4862, 0.4166, 0.4966, 0.2893, 0.3083])] 353 | Eval: Epoch: 319 dataset: SUN397 ACC: 0.6126615376879369 354 | Eval: Epoch: 319 dataset: Cars ACC: 0.5853749533640095 355 | 
Eval: Epoch: 319 dataset: RESISC45 ACC: 0.7347619047619047 356 | Eval: Epoch: 319 dataset: EuroSAT ACC: 0.8281481481481482 357 | Eval: Epoch: 319 dataset: SVHN ACC: 0.8670866625691457 358 | Eval: Epoch: 319 dataset: GTSRB ACC: 0.7926365795724466 359 | Eval: Epoch: 319 dataset: MNIST ACC: 0.9663 360 | Eval: Epoch: 319 dataset: DTD ACC: 0.5148936170212766 361 | Eval: Epoch: 319 Avg ACC:0.7377329253906085 362 | 363 | [tensor([1.0000, 0.3216, 0.2195, 0.4032, 0.4903, 0.4195, 0.5017, 0.2877, 0.3084])] 364 | Eval: Epoch: 329 dataset: SUN397 ACC: 0.6120581284256047 365 | Eval: Epoch: 329 dataset: Cars ACC: 0.5843800522323094 366 | Eval: Epoch: 329 dataset: RESISC45 ACC: 0.7347619047619047 367 | Eval: Epoch: 329 dataset: EuroSAT ACC: 0.8296296296296296 368 | Eval: Epoch: 329 dataset: SVHN ACC: 0.8681238475722188 369 | Eval: Epoch: 329 dataset: GTSRB ACC: 0.7945368171021377 370 | Eval: Epoch: 329 dataset: MNIST ACC: 0.9661 371 | Eval: Epoch: 329 dataset: DTD ACC: 0.5138297872340426 372 | Eval: Epoch: 329 Avg ACC:0.7379275208697309 373 | 374 | [tensor([1.0000, 0.3224, 0.2152, 0.4049, 0.4936, 0.4227, 0.5080, 0.2860, 0.3086])] 375 | Eval: Epoch: 339 dataset: SUN397 ACC: 0.6115552873736612 376 | Eval: Epoch: 339 dataset: Cars ACC: 0.5826389752518344 377 | Eval: Epoch: 339 dataset: RESISC45 ACC: 0.7336507936507937 378 | Eval: Epoch: 339 dataset: EuroSAT ACC: 0.8296296296296296 379 | Eval: Epoch: 339 dataset: SVHN ACC: 0.8690457897971727 380 | Eval: Epoch: 339 dataset: GTSRB ACC: 0.7966745843230404 381 | Eval: Epoch: 339 dataset: MNIST ACC: 0.9661 382 | Eval: Epoch: 339 dataset: DTD ACC: 0.5122340425531915 383 | Eval: Epoch: 339 Avg ACC:0.7376911378224155 384 | 385 | [tensor([1.0000, 0.3216, 0.2118, 0.4060, 0.4965, 0.4258, 0.5153, 0.2827, 0.3070])] 386 | Eval: Epoch: 349 dataset: SUN397 ACC: 0.6108010257957459 387 | Eval: Epoch: 349 dataset: Cars ACC: 0.5812709861957468 388 | Eval: Epoch: 349 dataset: RESISC45 ACC: 0.7322222222222222 389 | Eval: Epoch: 349 dataset: EuroSAT ACC: 0.8296296296296296 390 | Eval: Epoch: 349 dataset: SVHN ACC: 0.8699677320221266 391 | Eval: Epoch: 349 dataset: GTSRB ACC: 0.7996832937450514 392 | Eval: Epoch: 349 dataset: MNIST ACC: 0.9655 393 | Eval: Epoch: 349 dataset: DTD ACC: 0.5117021276595745 394 | Eval: Epoch: 349 Avg ACC:0.7375971271587621 395 | 396 | [tensor([1.0000, 0.3193, 0.2079, 0.4089, 0.4994, 0.4267, 0.5212, 0.2778, 0.3044])] 397 | Eval: Epoch: 359 dataset: SUN397 ACC: 0.6105496052697742 398 | Eval: Epoch: 359 dataset: Cars ACC: 0.5806491729884343 399 | Eval: Epoch: 359 dataset: RESISC45 ACC: 0.7331746031746031 400 | Eval: Epoch: 359 dataset: EuroSAT ACC: 0.8311111111111111 401 | Eval: Epoch: 359 dataset: SVHN ACC: 0.8700061462814997 402 | Eval: Epoch: 359 dataset: GTSRB ACC: 0.8027711797307997 403 | Eval: Epoch: 359 dataset: MNIST ACC: 0.9647 404 | Eval: Epoch: 359 dataset: DTD ACC: 0.5111702127659574 405 | Eval: Epoch: 359 Avg ACC:0.7380165039152725 406 | 407 | [tensor([1.0000, 0.3177, 0.2036, 0.4110, 0.5017, 0.4284, 0.5256, 0.2744, 0.3035])] 408 | Eval: Epoch: 369 dataset: SUN397 ACC: 0.6101976165334138 409 | Eval: Epoch: 369 dataset: Cars ACC: 0.5799029971396592 410 | Eval: Epoch: 369 dataset: RESISC45 ACC: 0.733015873015873 411 | Eval: Epoch: 369 dataset: EuroSAT ACC: 0.8314814814814815 412 | Eval: Epoch: 369 dataset: SVHN ACC: 0.8705055316533498 413 | Eval: Epoch: 369 dataset: GTSRB ACC: 0.8051464766429137 414 | Eval: Epoch: 369 dataset: MNIST ACC: 0.9642 415 | Eval: Epoch: 369 dataset: DTD ACC: 0.5101063829787233 416 | Eval: Epoch: 369 Avg 
ACC:0.7380695449306767 417 | 418 | [tensor([1.0000, 0.3179, 0.2009, 0.4134, 0.5041, 0.4309, 0.5308, 0.2733, 0.3036])] 419 | Eval: Epoch: 379 dataset: SUN397 ACC: 0.6091919344295268 420 | Eval: Epoch: 379 dataset: Cars ACC: 0.5796542718567342 421 | Eval: Epoch: 379 dataset: RESISC45 ACC: 0.7331746031746031 422 | Eval: Epoch: 379 dataset: EuroSAT ACC: 0.8322222222222222 423 | Eval: Epoch: 379 dataset: SVHN ACC: 0.8710049170251998 424 | Eval: Epoch: 379 dataset: GTSRB ACC: 0.806492478226445 425 | Eval: Epoch: 379 dataset: MNIST ACC: 0.9641 426 | Eval: Epoch: 379 dataset: DTD ACC: 0.5101063829787233 427 | Eval: Epoch: 379 Avg ACC:0.7382433512391817 428 | 429 | [tensor([1.0000, 0.3179, 0.1964, 0.4167, 0.5077, 0.4299, 0.5343, 0.2700, 0.3007])] 430 | Eval: Epoch: 389 dataset: SUN397 ACC: 0.609795343691859 431 | Eval: Epoch: 389 dataset: Cars ACC: 0.5782862828006466 432 | Eval: Epoch: 389 dataset: RESISC45 ACC: 0.7339682539682539 433 | Eval: Epoch: 389 dataset: EuroSAT ACC: 0.8340740740740741 434 | Eval: Epoch: 389 dataset: SVHN ACC: 0.870620774431469 435 | Eval: Epoch: 389 dataset: GTSRB ACC: 0.8079176563737134 436 | Eval: Epoch: 389 dataset: MNIST ACC: 0.9637 437 | Eval: Epoch: 389 dataset: DTD ACC: 0.5085106382978724 438 | Eval: Epoch: 389 Avg ACC:0.738359127954736 439 | 440 | [tensor([1.0000, 0.3180, 0.1922, 0.4197, 0.5098, 0.4293, 0.5374, 0.2653, 0.2978])] 441 | Eval: Epoch: 399 dataset: SUN397 ACC: 0.6103484688489969 442 | Eval: Epoch: 399 dataset: Cars ACC: 0.5777888322347967 443 | Eval: Epoch: 399 dataset: RESISC45 ACC: 0.7347619047619047 444 | Eval: Epoch: 399 dataset: EuroSAT ACC: 0.8348148148148148 445 | Eval: Epoch: 399 dataset: SVHN ACC: 0.870159803318992 446 | Eval: Epoch: 399 dataset: GTSRB ACC: 0.809501187648456 447 | Eval: Epoch: 399 dataset: MNIST ACC: 0.9628 448 | Eval: Epoch: 399 dataset: DTD ACC: 0.5074468085106383 449 | Eval: Epoch: 399 Avg ACC:0.7384527275173249 450 | 451 | [tensor([1.0000, 0.3186, 0.1893, 0.4205, 0.5113, 0.4312, 0.5423, 0.2636, 0.2945])] 452 | Eval: Epoch: 409 dataset: SUN397 ACC: 0.6102479006386081 453 | Eval: Epoch: 409 dataset: Cars ACC: 0.5770426563860216 454 | Eval: Epoch: 409 dataset: RESISC45 ACC: 0.7344444444444445 455 | Eval: Epoch: 409 dataset: EuroSAT ACC: 0.8348148148148148 456 | Eval: Epoch: 409 dataset: SVHN ACC: 0.870620774431469 457 | Eval: Epoch: 409 dataset: GTSRB ACC: 0.8119556611243072 458 | Eval: Epoch: 409 dataset: MNIST ACC: 0.9624 459 | Eval: Epoch: 409 dataset: DTD ACC: 0.5063829787234042 460 | Eval: Epoch: 409 Avg ACC:0.7384886538203836 461 | 462 | [tensor([1.0000, 0.3190, 0.1867, 0.4226, 0.5133, 0.4322, 0.5460, 0.2617, 0.2912])] 463 | Eval: Epoch: 419 dataset: SUN397 ACC: 0.6103484688489969 464 | Eval: Epoch: 419 dataset: Cars ACC: 0.5757990299713966 465 | Eval: Epoch: 419 dataset: RESISC45 ACC: 0.734920634920635 466 | Eval: Epoch: 419 dataset: EuroSAT ACC: 0.8348148148148148 467 | Eval: Epoch: 419 dataset: SVHN ACC: 0.8708512599877074 468 | Eval: Epoch: 419 dataset: GTSRB ACC: 0.8129057798891528 469 | Eval: Epoch: 419 dataset: MNIST ACC: 0.9621 470 | Eval: Epoch: 419 dataset: DTD ACC: 0.5042553191489362 471 | Eval: Epoch: 419 Avg ACC:0.738249413447705 472 | 473 | [tensor([1.0000, 0.3194, 0.1848, 0.4243, 0.5150, 0.4321, 0.5501, 0.2585, 0.2900])] 474 | Eval: Epoch: 429 dataset: SUN397 ACC: 0.6098456277970533 475 | Eval: Epoch: 429 dataset: Cars ACC: 0.5745554035567716 476 | Eval: Epoch: 429 dataset: RESISC45 ACC: 0.7350793650793651 477 | Eval: Epoch: 429 dataset: EuroSAT ACC: 0.8366666666666667 478 | Eval: Epoch: 429 dataset: SVHN 
ACC: 0.8705823601720959 479 | Eval: Epoch: 429 dataset: GTSRB ACC: 0.8142517814726841 480 | Eval: Epoch: 429 dataset: MNIST ACC: 0.9612 481 | Eval: Epoch: 429 dataset: DTD ACC: 0.5042553191489362 482 | Eval: Epoch: 429 Avg ACC:0.7383045654866965 483 | 484 | [tensor([1.0000, 0.3198, 0.1828, 0.4259, 0.5156, 0.4334, 0.5541, 0.2581, 0.2910])] 485 | Eval: Epoch: 439 dataset: SUN397 ACC: 0.6091919344295268 486 | Eval: Epoch: 439 dataset: Cars ACC: 0.5740579529909215 487 | Eval: Epoch: 439 dataset: RESISC45 ACC: 0.7350793650793651 488 | Eval: Epoch: 439 dataset: EuroSAT ACC: 0.8366666666666667 489 | Eval: Epoch: 439 dataset: SVHN ACC: 0.8709665027658267 490 | Eval: Epoch: 439 dataset: GTSRB ACC: 0.8155977830562153 491 | Eval: Epoch: 439 dataset: MNIST ACC: 0.9611 492 | Eval: Epoch: 439 dataset: DTD ACC: 0.5037234042553191 493 | Eval: Epoch: 439 Avg ACC:0.73829795115548 494 | 495 | [tensor([1.0000, 0.3201, 0.1804, 0.4269, 0.5161, 0.4341, 0.5577, 0.2541, 0.2909])] 496 | Eval: Epoch: 449 dataset: SUN397 ACC: 0.6094936390606929 497 | Eval: Epoch: 449 dataset: Cars ACC: 0.5733117771421465 498 | Eval: Epoch: 449 dataset: RESISC45 ACC: 0.7353968253968254 499 | Eval: Epoch: 449 dataset: EuroSAT ACC: 0.837037037037037 500 | Eval: Epoch: 449 dataset: SVHN ACC: 0.8711969883220652 501 | Eval: Epoch: 449 dataset: GTSRB ACC: 0.8166270783847981 502 | Eval: Epoch: 449 dataset: MNIST ACC: 0.96 503 | Eval: Epoch: 449 dataset: DTD ACC: 0.5037234042553191 504 | Eval: Epoch: 449 Avg ACC:0.7383483436998605 505 | 506 | [tensor([1.0000, 0.3190, 0.1755, 0.4256, 0.5159, 0.4341, 0.5609, 0.2507, 0.2898])] 507 | Eval: Epoch: 459 dataset: SUN397 ACC: 0.6099964801126364 508 | Eval: Epoch: 459 dataset: Cars ACC: 0.5723168760104465 509 | Eval: Epoch: 459 dataset: RESISC45 ACC: 0.7353968253968254 510 | Eval: Epoch: 459 dataset: EuroSAT ACC: 0.8362962962962963 511 | Eval: Epoch: 459 dataset: SVHN ACC: 0.8712354025814383 512 | Eval: Epoch: 459 dataset: GTSRB ACC: 0.8187648456057007 513 | Eval: Epoch: 459 dataset: MNIST ACC: 0.959 514 | Eval: Epoch: 459 dataset: DTD ACC: 0.5037234042553191 515 | Eval: Epoch: 459 Avg ACC:0.7383412662823328 516 | 517 | [tensor([1.0000, 0.3193, 0.1727, 0.4249, 0.5163, 0.4341, 0.5663, 0.2486, 0.2893])] 518 | Eval: Epoch: 469 dataset: SUN397 ACC: 0.609946196007442 519 | Eval: Epoch: 469 dataset: Cars ACC: 0.5714463375202089 520 | Eval: Epoch: 469 dataset: RESISC45 ACC: 0.7352380952380952 521 | Eval: Epoch: 469 dataset: EuroSAT ACC: 0.8351851851851851 522 | Eval: Epoch: 469 dataset: SVHN ACC: 0.8711969883220652 523 | Eval: Epoch: 469 dataset: GTSRB ACC: 0.8205067300079176 524 | Eval: Epoch: 469 dataset: MNIST ACC: 0.9586 525 | Eval: Epoch: 469 dataset: DTD ACC: 0.5037234042553191 526 | Eval: Epoch: 469 Avg ACC:0.7382303670670292 527 | 528 | [tensor([1.0000, 0.3207, 0.1719, 0.4247, 0.5183, 0.4351, 0.5730, 0.2475, 0.2902])] 529 | Eval: Epoch: 479 dataset: SUN397 ACC: 0.6091919344295268 530 | Eval: Epoch: 479 dataset: Cars ACC: 0.5700783484641214 531 | Eval: Epoch: 479 dataset: RESISC45 ACC: 0.7333333333333333 532 | Eval: Epoch: 479 dataset: EuroSAT ACC: 0.835925925925926 533 | Eval: Epoch: 479 dataset: SVHN ACC: 0.8717347879532883 534 | Eval: Epoch: 479 dataset: GTSRB ACC: 0.8224069675376089 535 | Eval: Epoch: 479 dataset: MNIST ACC: 0.9582 536 | Eval: Epoch: 479 dataset: DTD ACC: 0.5031914893617021 537 | Eval: Epoch: 479 Avg ACC:0.7380078483756884 538 | 539 | [tensor([1.0000, 0.3207, 0.1718, 0.4252, 0.5196, 0.4364, 0.5769, 0.2461, 0.2894])] 540 | Eval: Epoch: 489 dataset: SUN397 ACC: 0.6088399456931664 541 
| Eval: Epoch: 489 dataset: Cars ACC: 0.5693321726153463 542 | Eval: Epoch: 489 dataset: RESISC45 ACC: 0.733015873015873 543 | Eval: Epoch: 489 dataset: EuroSAT ACC: 0.8362962962962963 544 | Eval: Epoch: 489 dataset: SVHN ACC: 0.8721189305470191 545 | Eval: Epoch: 489 dataset: GTSRB ACC: 0.8234362628661916 546 | Eval: Epoch: 489 dataset: MNIST ACC: 0.9578 547 | Eval: Epoch: 489 dataset: DTD ACC: 0.502127659574468 548 | Eval: Epoch: 489 Avg ACC:0.737870892576045 549 | 550 | [tensor([1.0000, 0.3171, 0.1698, 0.4235, 0.5198, 0.4386, 0.5803, 0.2452, 0.2885])] 551 | Eval: Epoch: 499 dataset: SUN397 ACC: 0.6081862523256398 552 | Eval: Epoch: 499 dataset: Cars ACC: 0.5690834473324213 553 | Eval: Epoch: 499 dataset: RESISC45 ACC: 0.7317460317460317 554 | Eval: Epoch: 499 dataset: EuroSAT ACC: 0.8348148148148148 555 | Eval: Epoch: 499 dataset: SVHN ACC: 0.8730792870313461 556 | Eval: Epoch: 499 dataset: GTSRB ACC: 0.8247822644497229 557 | Eval: Epoch: 499 dataset: MNIST ACC: 0.9577 558 | Eval: Epoch: 499 dataset: DTD ACC: 0.5015957446808511 559 | Eval: Epoch: 499 Avg ACC:0.7376234802976035 560 | -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | def parse_arguments(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--data-location", 10 | type=str, 11 | default=os.path.expanduser('~/data'), 12 | help="The root directory for the datasets.", 13 | ) 14 | parser.add_argument( 15 | "--eval-datasets", 16 | default=None, 17 | type=lambda x: x.split(","), 18 | help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. " 19 | ) 20 | parser.add_argument( 21 | "--train-dataset", 22 | default=None, 23 | type=lambda x: x.split(","), 24 | help="Which dataset(s) to patch on.", 25 | ) 26 | parser.add_argument( 27 | "--exp_name", 28 | type=str, 29 | default=None, 30 | help="Name of the experiment, for organization purposes only." 31 | ) 32 | parser.add_argument( 33 | "--results-db", 34 | type=str, 35 | default=None, 36 | help="Where to store the results; if not set, results are not stored.", 37 | ) 38 | parser.add_argument( 39 | "--model", 40 | type=str, 41 | default=None, 42 | help="The type of model (e.g. RN50, ViT-B-32).", 43 | ) 44 | parser.add_argument( 45 | "--batch-size", 46 | type=int, 47 | default=128, 48 | ) 49 | parser.add_argument( 50 | "--lr", 51 | type=float, 52 | default=0.001, 53 | help="Learning rate." 54 | ) 55 | parser.add_argument( 56 | "--wd", 57 | type=float, 58 | default=0.1, 59 | help="Weight decay." 60 | ) 61 | parser.add_argument( 62 | "--ls", 63 | type=float, 64 | default=0.0, 65 | help="Label smoothing." 66 | ) 67 | parser.add_argument( 68 | "--warmup_length", 69 | type=int, 70 | default=500, 71 | ) 72 | parser.add_argument( 73 | "--epochs", 74 | type=int, 75 | default=10, 76 | ) 77 | parser.add_argument( 78 | "--load", 79 | type=lambda x: x.split(","), 80 | default=None, 81 | help="Optionally load _classifiers_, e.g. a zero-shot classifier, a linear probe, or an ensemble of both.", 82 | ) 83 | parser.add_argument( 84 | "--save", 85 | type=str, 86 | default=None, 87 | help="Optionally save a _classifier_, e.g. 
a zero shot classifier or probe.", 88 | ) 89 | parser.add_argument( 90 | "--cache-dir", 91 | type=str, 92 | default=None, 93 | help="Directory for caching features and encoder", 94 | ) 95 | parser.add_argument( 96 | "--openclip-cachedir", 97 | type=str, 98 | default='/gscratch/efml/gamaga/.cache/open_clip', 99 | help='Directory for caching models from OpenCLIP' 100 | ) 101 | parsed_args = parser.parse_args() 102 | parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu" 103 | 104 | if parsed_args.load is not None and len(parsed_args.load) == 1: 105 | parsed_args.load = parsed_args.load[0] 106 | return parsed_args 107 | -------------------------------------------------------------------------------- /src/datasets/cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | 6 | import pathlib 7 | from typing import Callable, Optional, Any, Tuple 8 | 9 | from PIL import Image 10 | 11 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg 12 | from torchvision.datasets.vision import VisionDataset 13 | 14 | 15 | class PytorchStanfordCars(VisionDataset): 16 | """`Stanford Cars `_ Dataset 17 | 18 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is 19 | split into 8,144 training images and 8,041 testing images, where each class 20 | has been split roughly in a 50-50 split 21 | 22 | .. note:: 23 | 24 | This class needs `scipy `_ to load target files from `.mat` format. 25 | 26 | Args: 27 | root (string): Root directory of dataset 28 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. 29 | transform (callable, optional): A function/transform that takes in an PIL image 30 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 31 | target_transform (callable, optional): A function/transform that takes in the 32 | target and transforms it. 33 | download (bool, optional): If True, downloads the dataset from the internet and 34 | puts it in root directory. If dataset is already downloaded, it is not 35 | downloaded again.""" 36 | 37 | def __init__( 38 | self, 39 | root: str, 40 | split: str = "train", 41 | transform: Optional[Callable] = None, 42 | target_transform: Optional[Callable] = None, 43 | download: bool = False, 44 | ) -> None: 45 | 46 | try: 47 | import scipy.io as sio 48 | except ImportError: 49 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") 50 | 51 | super().__init__(root, transform=transform, target_transform=target_transform) 52 | 53 | self._split = verify_str_arg(split, "split", ("train", "test")) 54 | self._base_folder = pathlib.Path(root) / "stanford_cars" 55 | devkit = self._base_folder / "devkit" 56 | 57 | if self._split == "train": 58 | self._annotations_mat_path = devkit / "cars_train_annos.mat" 59 | self._images_base_path = self._base_folder / "cars_train" 60 | else: 61 | self._annotations_mat_path = devkit / "cars_test_annos_withlabels.mat" 62 | self._images_base_path = self._base_folder / "cars_test" 63 | 64 | if download: 65 | self.download() 66 | 67 | if not self._check_exists(): 68 | raise RuntimeError("Dataset not found. 
You can use download=True to download it") 69 | 70 | self._samples = [ 71 | ( 72 | str(self._images_base_path / annotation["fname"]), 73 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1 74 | ) 75 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] 76 | ] 77 | 78 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() 79 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 80 | 81 | def __len__(self) -> int: 82 | return len(self._samples) 83 | 84 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 85 | """Returns pil_image and class_id for given index""" 86 | image_path, target = self._samples[idx] 87 | pil_image = Image.open(image_path).convert("RGB") 88 | 89 | if self.transform is not None: 90 | pil_image = self.transform(pil_image) 91 | if self.target_transform is not None: 92 | target = self.target_transform(target) 93 | return pil_image, target 94 | 95 | 96 | def download(self) -> None: 97 | if self._check_exists(): 98 | return 99 | 100 | download_and_extract_archive( 101 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", 102 | download_root=str(self._base_folder), 103 | md5="c3b158d763b6e2245038c8ad08e45376", 104 | ) 105 | if self._split == "train": 106 | download_and_extract_archive( 107 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", 108 | download_root=str(self._base_folder), 109 | md5="065e5b463ae28d29e77c1b4b166cfe61", 110 | ) 111 | else: 112 | download_and_extract_archive( 113 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", 114 | download_root=str(self._base_folder), 115 | md5="4ce7ebf6a94d07f1952d94dd34c4d501", 116 | ) 117 | download_url( 118 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", 119 | root=str(self._base_folder), 120 | md5="b0a2b23655a3edd16d84508592a98d10", 121 | ) 122 | 123 | def _check_exists(self) -> bool: 124 | if not (self._base_folder / "devkit").is_dir(): 125 | return False 126 | 127 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir() 128 | 129 | 130 | class Cars: 131 | def __init__(self, 132 | preprocess, 133 | location=os.path.expanduser('~/data'), 134 | batch_size=32, 135 | num_workers=0): 136 | # Data loading code 137 | 138 | self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=False) 139 | self.train_loader = torch.utils.data.DataLoader( 140 | self.train_dataset, 141 | shuffle=True, 142 | batch_size=batch_size, 143 | num_workers=num_workers, 144 | ) 145 | 146 | self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=False) 147 | self.test_loader = torch.utils.data.DataLoader( 148 | self.test_dataset, 149 | batch_size=batch_size, 150 | num_workers=num_workers 151 | ) 152 | self.test_loader_shuffle = torch.utils.data.DataLoader( 153 | self.test_dataset, 154 | shuffle=True, 155 | batch_size=batch_size, 156 | num_workers=num_workers 157 | ) 158 | idx_to_class = dict((v, k) for k, v in self.train_dataset.class_to_idx.items()) 159 | self.classnames = [idx_to_class[i].replace( 160 | '_', ' ') for i in range(len(idx_to_class))] 161 | -------------------------------------------------------------------------------- /src/datasets/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import glob 5 | import collections 6 | import random 7 | 8 | import numpy as np 9 | 10 | from tqdm import tqdm 11 | 
12 | import torchvision.datasets as datasets 13 | from torch.utils.data import Dataset, DataLoader, Sampler 14 | 15 | 16 | class SubsetSampler(Sampler): 17 | def __init__(self, indices): 18 | self.indices = indices 19 | 20 | def __iter__(self): 21 | return (i for i in self.indices) 22 | 23 | def __len__(self): 24 | return len(self.indices) 25 | 26 | class ImageFolderWithPaths(datasets.ImageFolder): 27 | def __init__(self, path, transform, flip_label_prob=0.0): 28 | super().__init__(path, transform) 29 | self.flip_label_prob = flip_label_prob 30 | if self.flip_label_prob > 0: 31 | print(f'Flipping labels with probability {self.flip_label_prob}') 32 | num_classes = len(self.classes) 33 | for i in range(len(self.samples)): 34 | if random.random() < self.flip_label_prob: 35 | new_label = random.randint(0, num_classes-1) 36 | self.samples[i] = ( 37 | self.samples[i][0], 38 | new_label 39 | ) 40 | 41 | def __getitem__(self, index): 42 | image, label = super(ImageFolderWithPaths, self).__getitem__(index) 43 | return { 44 | 'images': image, 45 | 'labels': label, 46 | 'image_paths': self.samples[index][0] 47 | } 48 | 49 | 50 | def maybe_dictionarize(batch): 51 | if isinstance(batch, dict): 52 | return batch 53 | 54 | if len(batch) == 2: 55 | batch = {'images': batch[0], 'labels': batch[1]} 56 | elif len(batch) == 3: 57 | batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]} 58 | else: 59 | raise ValueError(f'Unexpected number of elements: {len(batch)}') 60 | 61 | return batch 62 | 63 | 64 | def get_features_helper(image_encoder, dataloader, device): 65 | all_data = collections.defaultdict(list) 66 | 67 | image_encoder = image_encoder.to(device) 68 | image_encoder = torch.nn.DataParallel(image_encoder, device_ids=list(range(torch.cuda.device_count()))) 69 | image_encoder.eval() 70 | 71 | with torch.no_grad(): 72 | for batch in tqdm(dataloader): 73 | batch = maybe_dictionarize(batch) 74 | features = image_encoder(batch['images'].to(device)) # use the requested device rather than assuming CUDA 75 | 76 | all_data['features'].append(features.cpu()) 77 | 78 | for key, val in batch.items(): 79 | if key == 'images': 80 | continue 81 | if hasattr(val, 'cpu'): 82 | val = val.cpu() 83 | all_data[key].append(val) 84 | else: 85 | all_data[key].extend(val) 86 | 87 | for key, val in all_data.items(): 88 | if torch.is_tensor(val[0]): 89 | all_data[key] = torch.cat(val).numpy() 90 | 91 | return all_data 92 | 93 | 94 | def get_features(is_train, image_encoder, dataset, device): 95 | split = 'train' if is_train else 'val' 96 | dname = type(dataset).__name__ 97 | # Define cache_dir/cached_files unconditionally so the else branch below cannot raise a NameError when no cache directory is set 98 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' if image_encoder.cache_dir is not None else None 99 | cached_files = glob.glob(f'{cache_dir}/*') if cache_dir is not None else [] 100 | if cache_dir is not None and len(cached_files) > 0: 101 | print(f'Getting features from {cache_dir}') 102 | data = {} 103 | for cached_file in cached_files: 104 | name = os.path.splitext(os.path.basename(cached_file))[0] 105 | data[name] = torch.load(cached_file) 106 | else: 107 | print(f'Did not find cached features at {cache_dir}. 
Building from scratch.') 108 | loader = dataset.train_loader if is_train else dataset.test_loader 109 | data = get_features_helper(image_encoder, loader, device) 110 | if image_encoder.cache_dir is None: 111 | print('Not caching because no cache directory was passed.') 112 | else: 113 | os.makedirs(cache_dir, exist_ok=True) 114 | print(f'Caching data at {cache_dir}') 115 | for name, val in data.items(): 116 | torch.save(val, f'{cache_dir}/{name}.pt') 117 | return data 118 | 119 | 120 | class FeatureDataset(Dataset): 121 | def __init__(self, is_train, image_encoder, dataset, device): 122 | self.data = get_features(is_train, image_encoder, dataset, device) 123 | 124 | def __len__(self): 125 | return len(self.data['features']) 126 | 127 | def __getitem__(self, idx): 128 | data = {k: v[idx] for k, v in self.data.items()} 129 | data['features'] = torch.from_numpy(data['features']).float() 130 | return data 131 | 132 | 133 | def get_dataloader(dataset, is_train, args, image_encoder=None): 134 | if image_encoder is not None: 135 | feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device) 136 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train) 137 | else: 138 | dataloader = dataset.train_loader if is_train else dataset.test_loader 139 | return dataloader 140 | 141 | def get_dataloader_shuffle(dataset): 142 | dataloader = dataset.test_loader_shuffle 143 | return dataloader -------------------------------------------------------------------------------- /src/datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | 6 | class DTD: 7 | def __init__(self, 8 | preprocess, 9 | location=os.path.expanduser('~/data'), 10 | batch_size=32, 11 | num_workers=0): 12 | # Data loading code 13 | traindir = os.path.join(location, 'dtd', 'train') 14 | valdir = os.path.join(location, 'dtd', 'test') 15 | 16 | self.train_dataset = datasets.ImageFolder( 17 | traindir, transform=preprocess) 18 | self.train_loader = torch.utils.data.DataLoader( 19 | self.train_dataset, 20 | shuffle=True, 21 | batch_size=batch_size, 22 | num_workers=num_workers, 23 | ) 24 | 25 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 26 | self.test_loader = torch.utils.data.DataLoader( 27 | self.test_dataset, 28 | batch_size=batch_size, 29 | num_workers=num_workers 30 | ) 31 | 32 | self.test_loader_shuffle = torch.utils.data.DataLoader( 33 | self.test_dataset, 34 | shuffle=True, 35 | batch_size=batch_size, 36 | num_workers=num_workers 37 | ) 38 | 39 | idx_to_class = dict((v, k) 40 | for k, v in self.train_dataset.class_to_idx.items()) 41 | self.classnames = [idx_to_class[i].replace( 42 | '_', ' ') for i in range(len(idx_to_class))] -------------------------------------------------------------------------------- /src/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | import re 5 | 6 | def pretify_classname(classname): 7 | l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname) 8 | l = [i.lower() for i in l] 9 | out = ' '.join(l) 10 | if out.endswith('al'): 11 | return out + ' area' 12 | return out 13 | 14 | class EuroSATBase: 15 | def __init__(self, 16 | preprocess, 17 | test_split, 18 | location='~/data', 19 | batch_size=32, 20 | num_workers=0): 21 | # Data loading code 22 | traindir = os.path.join(location, 
'EuroSAT_splits', 'train') 23 | testdir = os.path.join(location, 'EuroSAT_splits', test_split) 24 | 25 | 26 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 27 | self.train_loader = torch.utils.data.DataLoader( 28 | self.train_dataset, 29 | shuffle=True, 30 | batch_size=batch_size, 31 | num_workers=num_workers, 32 | ) 33 | 34 | self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess) 35 | self.test_loader = torch.utils.data.DataLoader( 36 | self.test_dataset, 37 | batch_size=batch_size, 38 | num_workers=num_workers 39 | ) 40 | self.test_loader_shuffle = torch.utils.data.DataLoader( 41 | self.test_dataset, 42 | shuffle=True, 43 | batch_size=batch_size, 44 | num_workers=num_workers 45 | ) 46 | idx_to_class = dict((v, k) 47 | for k, v in self.train_dataset.class_to_idx.items()) 48 | 49 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))] 50 | self.classnames = [pretify_classname(c) for c in self.classnames] 51 | ours_to_open_ai = { 52 | 'annual crop': 'annual crop land', 53 | 'forest': 'forest', 54 | 'herbaceous vegetation': 'brushland or shrubland', 55 | 'highway': 'highway or road', 56 | 'industrial area': 'industrial buildings or commercial buildings', 57 | 'pasture': 'pasture land', 58 | 'permanent crop': 'permanent crop land', 59 | 'residential area': 'residential buildings or homes or apartments', 60 | 'river': 'river', 61 | 'sea lake': 'lake or sea', 62 | } 63 | for i in range(len(self.classnames)): 64 | self.classnames[i] = ours_to_open_ai[self.classnames[i]] 65 | 66 | 67 | class EuroSAT(EuroSATBase): 68 | def __init__(self, 69 | preprocess, 70 | location='~/data', 71 | batch_size=32, 72 | num_workers=0): 73 | super().__init__(preprocess, 'test', location, batch_size, num_workers) 74 | 75 | 76 | class EuroSATVal(EuroSATBase): 77 | def __init__(self, 78 | preprocess, 79 | location='~/data', 80 | batch_size=32, 81 | num_workers=0): 82 | super().__init__(preprocess, 'val', location, batch_size, num_workers) 83 | -------------------------------------------------------------------------------- /src/datasets/gtsrb.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pathlib 4 | from typing import Any, Callable, Dict, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | from torchvision.datasets.folder import make_dataset 10 | from torchvision.datasets.utils import (download_and_extract_archive, verify_str_arg) 11 | from torchvision.datasets.vision import VisionDataset 12 | 13 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: 14 | """Finds the class folders in a dataset. 15 | 16 | See :class:`DatasetFolder` for details. 17 | """ 18 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 19 | if not classes: 20 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 21 | 22 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 23 | return classes, class_to_idx 24 | 25 | class PyTorchGTSRB(VisionDataset): 26 | """`German Traffic Sign Recognition Benchmark (GTSRB) <https://benchmark.ini.rub.de/>`_ Dataset. 27 | 28 | Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB. 29 | 30 | Args: 31 | root (string): Root directory of the dataset. 32 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. 
33 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed 34 | version. E.g., ``transforms.RandomCrop``. 35 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 36 | download (bool, optional): If True, downloads the dataset from the internet and 37 | puts it in root directory. If dataset is already downloaded, it is not 38 | downloaded again. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | root: str, 44 | split: str = "train", 45 | transform: Optional[Callable] = None, 46 | target_transform: Optional[Callable] = None, 47 | download: bool = False, 48 | ) -> None: 49 | 50 | super().__init__(root, transform=transform, target_transform=target_transform) 51 | 52 | self._split = verify_str_arg(split, "split", ("train", "test")) 53 | self._base_folder = pathlib.Path(root) / "gtsrb" 54 | self._target_folder = ( 55 | self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images") 56 | ) 57 | 58 | if download: 59 | self.download() 60 | 61 | if not self._check_exists(): 62 | raise RuntimeError("Dataset not found. You can use download=True to download it") 63 | 64 | if self._split == "train": 65 | _, class_to_idx = find_classes(str(self._target_folder)) 66 | samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx) 67 | else: 68 | with open(self._base_folder / "GT-final_test.csv") as csv_file: 69 | samples = [ 70 | (str(self._target_folder / row["Filename"]), int(row["ClassId"])) 71 | for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True) 72 | ] 73 | 74 | self._samples = samples 75 | self.transform = transform 76 | self.target_transform = target_transform 77 | 78 | def __len__(self) -> int: 79 | return len(self._samples) 80 | 81 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 82 | 83 | path, target = self._samples[index] 84 | sample = PIL.Image.open(path).convert("RGB") 85 | 86 | if self.transform is not None: 87 | sample = self.transform(sample) 88 | 89 | if self.target_transform is not None: 90 | target = self.target_transform(target) 91 | 92 | return sample, target 93 | 94 | 95 | def _check_exists(self) -> bool: 96 | return self._target_folder.is_dir() 97 | 98 | def download(self) -> None: 99 | if self._check_exists(): 100 | return 101 | 102 | base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" 103 | 104 | if self._split == "train": 105 | download_and_extract_archive( 106 | f"{base_url}GTSRB-Training_fixed.zip", 107 | download_root=str(self._base_folder), 108 | md5="513f3c79a4c5141765e10e952eaa2478", 109 | ) 110 | else: 111 | download_and_extract_archive( 112 | f"{base_url}GTSRB_Final_Test_Images.zip", 113 | download_root=str(self._base_folder), 114 | md5="c7e4e6327067d32654124b0fe9e82185", 115 | ) 116 | download_and_extract_archive( 117 | f"{base_url}GTSRB_Final_Test_GT.zip", 118 | download_root=str(self._base_folder), 119 | md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5", 120 | ) 121 | 122 | 123 | class GTSRB: 124 | def __init__(self, 125 | preprocess, 126 | location=os.path.expanduser('~/data'), 127 | batch_size=128, 128 | num_workers=0): 129 | 130 | # PyTorchGTSRB appends 'gtsrb' to the given root internally, so `location` follows the repo-wide data-location convention 131 | self.train_dataset = PyTorchGTSRB( 132 | root=location, 133 | download=False, 134 | split='train', 135 | transform=preprocess 136 | ) 137 | 138 | self.train_loader = torch.utils.data.DataLoader( 139 | self.train_dataset, 140 | batch_size=batch_size, 141 | shuffle=True, 142 | 
num_workers=num_workers 143 | ) 144 | 145 | self.test_dataset = PyTorchGTSRB( 146 | root=location, 147 | download=False, 148 | split='test', 149 | transform=preprocess 150 | ) 151 | 152 | self.test_loader = torch.utils.data.DataLoader( 153 | self.test_dataset, 154 | batch_size=batch_size, 155 | shuffle=False, 156 | num_workers=num_workers 157 | ) 158 | 159 | self.test_loader_shuffle = torch.utils.data.DataLoader( 160 | self.test_dataset, 161 | batch_size=batch_size, 162 | shuffle=True, 163 | num_workers=num_workers 164 | ) 165 | 166 | # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md 167 | self.classnames = [ 168 | 'red and white circle 20 kph speed limit', 169 | 'red and white circle 30 kph speed limit', 170 | 'red and white circle 50 kph speed limit', 171 | 'red and white circle 60 kph speed limit', 172 | 'red and white circle 70 kph speed limit', 173 | 'red and white circle 80 kph speed limit', 174 | 'end / de-restriction of 80 kph speed limit', 175 | 'red and white circle 100 kph speed limit', 176 | 'red and white circle 120 kph speed limit', 177 | 'red and white circle red car and black car no passing', 178 | 'red and white circle red truck and black car no passing', 179 | 'red and white triangle road intersection warning', 180 | 'white and yellow diamond priority road', 181 | 'red and white upside down triangle yield right-of-way', 182 | 'stop', 183 | 'empty red and white circle', 184 | 'red and white circle no truck entry', 185 | 'red circle with white horizontal stripe no entry', 186 | 'red and white triangle with exclamation mark warning', 187 | 'red and white triangle with black left curve approaching warning', 188 | 'red and white triangle with black right curve approaching warning', 189 | 'red and white triangle with black double curve approaching warning', 190 | 'red and white triangle rough / bumpy road warning', 191 | 'red and white triangle car skidding / slipping warning', 192 | 'red and white triangle with merging / narrow lanes warning', 193 | 'red and white triangle with person digging / construction / road work warning', 194 | 'red and white triangle with traffic light approaching warning', 195 | 'red and white triangle with person walking warning', 196 | 'red and white triangle with child and person walking warning', 197 | 'red and white triangle with bicycle warning', 198 | 'red and white triangle with snowflake / ice warning', 199 | 'red and white triangle with deer warning', 200 | 'white circle with gray strike bar no speed limit', 201 | 'blue circle with white right turn arrow mandatory', 202 | 'blue circle with white left turn arrow mandatory', 203 | 'blue circle with white forward arrow mandatory', 204 | 'blue circle with white forward or right turn arrow mandatory', 205 | 'blue circle with white forward or left turn arrow mandatory', 206 | 'blue circle with white keep right arrow mandatory', 207 | 'blue circle with white keep left arrow mandatory', 208 | 'blue circle with white arrows indicating a traffic circle', 209 | 'white circle with gray strike bar indicating no passing for cars has ended', 210 | 'white circle with gray strike bar indicating no passing for trucks has ended', 211 | ] 212 | -------------------------------------------------------------------------------- /src/datasets/mnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class MNIST: 6 | def __init__(self, 7 | preprocess, 8 | 
location=os.path.expanduser('~/data'), 9 | batch_size=128, 10 | num_workers=0): 11 | 12 | 13 | self.train_dataset = datasets.MNIST( 14 | root=location, 15 | download=True, 16 | train=True, 17 | transform=preprocess 18 | ) 19 | 20 | self.train_loader = torch.utils.data.DataLoader( 21 | self.train_dataset, 22 | batch_size=batch_size, 23 | shuffle=True, 24 | num_workers=num_workers 25 | ) 26 | 27 | self.test_dataset = datasets.MNIST( 28 | root=location, 29 | download=True, 30 | train=False, 31 | transform=preprocess 32 | ) 33 | 34 | self.test_loader = torch.utils.data.DataLoader( 35 | self.test_dataset, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers 39 | ) 40 | 41 | self.test_loader_shuffle = torch.utils.data.DataLoader( 42 | self.test_dataset, 43 | batch_size=batch_size, 44 | shuffle=True, 45 | num_workers=num_workers 46 | ) 47 | 48 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] -------------------------------------------------------------------------------- /src/datasets/registry.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import random 4 | import torch 5 | import copy 6 | 7 | from torch.utils.data.dataset import random_split 8 | 9 | from datasets.cars import Cars 10 | from datasets.dtd import DTD 11 | from datasets.eurosat import EuroSAT, EuroSATVal 12 | from datasets.gtsrb import GTSRB 13 | from datasets.mnist import MNIST 14 | from datasets.resisc45 import RESISC45 15 | from datasets.svhn import SVHN 16 | from datasets.sun397 import SUN397 17 | 18 | registry = { 19 | name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) 20 | } 21 | 22 | 23 | class GenericDataset(object): 24 | def __init__(self): 25 | self.train_dataset = None 26 | self.train_loader = None 27 | self.test_dataset = None 28 | self.test_loader = None 29 | self.classnames = None 30 | 31 | 32 | def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0): 33 | assert val_fraction > 0. and val_fraction < 1. 
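# Worked example of the arithmetic below, using the defaults that get_dataset
# passes in (val_fraction=0.1, max_val_samples=5000): MNIST's 60,000 training
# images give val_size = min(int(60000 * 0.1), 5000) = 5000 and
# train_size = 55,000. The seeded torch.Generator keeps the random split
# reproducible across runs, which the MNISTVal `indices[0] == 36044` sanity
# check further down relies on.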
34 | total_size = len(dataset.train_dataset) 35 | val_size = int(total_size * val_fraction) 36 | if max_val_samples is not None: 37 | val_size = min(val_size, max_val_samples) 38 | train_size = total_size - val_size 39 | 40 | assert val_size > 0 41 | assert train_size > 0 42 | 43 | lengths = [train_size, val_size] 44 | 45 | trainset, valset = random_split( 46 | dataset.train_dataset, 47 | lengths, 48 | generator=torch.Generator().manual_seed(seed) 49 | ) 50 | if new_dataset_class_name == 'MNISTVal': 51 | assert trainset.indices[0] == 36044 52 | 53 | 54 | new_dataset = None 55 | 56 | new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {}) 57 | new_dataset = new_dataset_class() 58 | 59 | new_dataset.train_dataset = trainset 60 | new_dataset.train_loader = torch.utils.data.DataLoader( 61 | new_dataset.train_dataset, 62 | shuffle=True, 63 | batch_size=batch_size, 64 | num_workers=num_workers, 65 | ) 66 | 67 | new_dataset.test_dataset = valset 68 | new_dataset.test_loader = torch.utils.data.DataLoader( 69 | new_dataset.test_dataset, 70 | batch_size=batch_size, 71 | num_workers=num_workers 72 | ) 73 | 74 | new_dataset.test_loader_shuffle = torch.utils.data.DataLoader( 75 | new_dataset.test_dataset, 76 | shuffle=True, 77 | batch_size=batch_size, 78 | num_workers=num_workers 79 | ) 80 | 81 | new_dataset.classnames = copy.copy(dataset.classnames) 82 | 83 | return new_dataset 84 | 85 | 86 | def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=0, val_fraction=0.1, max_val_samples=5000): 87 | if dataset_name.endswith('Val'): 88 | # Handle val splits 89 | if dataset_name in registry: 90 | dataset_class = registry[dataset_name] 91 | else: 92 | base_dataset_name = dataset_name.split('Val')[0] 93 | base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers) 94 | dataset = split_train_into_train_val( 95 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples) 96 | return dataset 97 | else: 98 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}' 99 | dataset_class = registry[dataset_name] 100 | dataset = dataset_class( 101 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers 102 | ) 103 | return dataset 104 | -------------------------------------------------------------------------------- /src/datasets/resisc45.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | import abc 5 | import os 6 | from typing import Any, Callable, Dict, Optional, Tuple 7 | 8 | import numpy as np 9 | import torch 10 | from torch import Tensor 11 | from torch.utils.data import Dataset 12 | from torchvision.datasets import ImageFolder 13 | from torchvision.datasets.folder import default_loader as pil_loader 14 | 15 | 16 | # modified from: https://github.com/microsoft/torchgeo 17 | class VisionDataset(Dataset[Dict[str, Any]], abc.ABC): 18 | """Abstract base class for datasets lacking geospatial information. 19 | This base class is designed for datasets with pre-defined image chips. 20 | """ 21 | 22 | @abc.abstractmethod 23 | def __getitem__(self, index: int) -> Dict[str, Any]: 24 | """Return an index within the dataset. 
25 | Args: 26 | index: index to return 27 | Returns: 28 | data and labels at that index 29 | Raises: 30 | IndexError: if index is out of range of the dataset 31 | """ 32 | 33 | @abc.abstractmethod 34 | def __len__(self) -> int: 35 | """Return the length of the dataset. 36 | Returns: 37 | length of the dataset 38 | """ 39 | 40 | def __str__(self) -> str: 41 | """Return the informal string representation of the object. 42 | Returns: 43 | informal string representation 44 | """ 45 | return f"""\ 46 | {self.__class__.__name__} Dataset 47 | type: VisionDataset 48 | size: {len(self)}""" 49 | 50 | 51 | class VisionClassificationDataset(VisionDataset, ImageFolder): 52 | """Abstract base class for classification datasets lacking geospatial information. 53 | This base class is designed for datasets with pre-defined image chips which 54 | are organized into separate folders per class. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | root: str, 60 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 61 | loader: Optional[Callable[[str], Any]] = pil_loader, 62 | is_valid_file: Optional[Callable[[str], bool]] = None, 63 | ) -> None: 64 | """Initialize a new VisionClassificationDataset instance. 65 | Args: 66 | root: root directory where dataset can be found 67 | transforms: a function/transform that takes input sample and its target as 68 | entry and returns a transformed version 69 | loader: a callable function which takes as input a path to an image and 70 | returns a PIL Image or numpy array 71 | is_valid_file: A function that takes the path of an Image file and checks if 72 | the file is a valid file 73 | """ 74 | # When transform & target_transform are None, ImageFolder.__getitem__(index) 75 | # returns a PIL.Image and int for image and label, respectively 76 | super().__init__( 77 | root=root, 78 | transform=None, 79 | target_transform=None, 80 | loader=loader, 81 | is_valid_file=is_valid_file, 82 | ) 83 | 84 | # Must be set after calling super().__init__() 85 | self.transforms = transforms 86 | 87 | def __getitem__(self, index: int) -> Dict[str, Tensor]: 88 | """Return an index within the dataset. 89 | Args: 90 | index: index to return 91 | Returns: 92 | data and label at that index 93 | """ 94 | image, label = self._load_image(index) 95 | 96 | if self.transforms is not None: 97 | return self.transforms(image), label 98 | 99 | return image, label 100 | 101 | def __len__(self) -> int: 102 | """Return the number of data points in the dataset. 103 | Returns: 104 | length of the dataset 105 | """ 106 | return len(self.imgs) 107 | 108 | def _load_image(self, index: int) -> Tuple[Tensor, Tensor]: 109 | """Load a single image and its class label. 110 | Args: 111 | index: index to return 112 | Returns: 113 | the image 114 | the image class label 115 | """ 116 | img, label = ImageFolder.__getitem__(self, index) 117 | label = torch.tensor(label) 118 | return img, label 119 | 120 | 121 | class RESISC45Dataset(VisionClassificationDataset): 122 | """RESISC45 dataset. 123 | The RESISC45 124 | dataset is a dataset for remote sensing image scene classification. 125 | Dataset features: 126 | * 31,500 images with 0.2-30 m per pixel resolution (256x256 px) 127 | * three spectral bands - RGB 128 | * 45 scene classes, 700 images per class 129 | * images extracted from Google Earth from over 100 countries 130 | * image conditions with high variability (resolution, weather, illumination) 131 | Dataset format: 132 | * images are three-channel jpgs 133 | Dataset classes: 134 | 0. 
airplane 135 | 1. airport 136 | 2. baseball_diamond 137 | 3. basketball_court 138 | 4. beach 139 | 5. bridge 140 | 6. chaparral 141 | 7. church 142 | 8. circular_farmland 143 | 9. cloud 144 | 10. commercial_area 145 | 11. dense_residential 146 | 12. desert 147 | 13. forest 148 | 14. freeway 149 | 15. golf_course 150 | 16. ground_track_field 151 | 17. harbor 152 | 18. industrial_area 153 | 19. intersection 154 | 20. island 155 | 21. lake 156 | 22. meadow 157 | 23. medium_residential 158 | 24. mobile_home_park 159 | 25. mountain 160 | 26. overpass 161 | 27. palace 162 | 28. parking_lot 163 | 29. railway 164 | 30. railway_station 165 | 31. rectangular_farmland 166 | 32. river 167 | 33. roundabout 168 | 34. runway 169 | 35. sea_ice 170 | 36. ship 171 | 37. snowberg 172 | 38. sparse_residential 173 | 39. stadium 174 | 40. storage_tank 175 | 41. tennis_court 176 | 42. terrace 177 | 43. thermal_power_station 178 | 44. wetland 179 | This dataset uses the train/val/test splits defined in the "In-domain representation 180 | learning for remote sensing" paper: 181 | * https://arxiv.org/abs/1911.06721 182 | If you use this dataset in your research, please cite the following paper: 183 | * https://doi.org/10.1109/jproc.2017.2675998 184 | """ 185 | 186 | # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv" 187 | # md5 = "d824acb73957502b00efd559fc6cfbbb" 188 | # filename = "NWPU-RESISC45.rar" 189 | directory = "resisc45/NWPU-RESISC45" 190 | 191 | splits = ["train", "val", "test"] 192 | split_urls = { 193 | "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501 194 | "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501 195 | "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501 196 | } 197 | split_md5s = { 198 | "train": "b5a4c05a37de15e4ca886696a85c403e", 199 | "val": "a0770cee4c5ca20b8c32bbd61e114805", 200 | "test": "3dda9e4988b47eb1de9f07993653eb08", 201 | } 202 | classes = [ 203 | "airplane", 204 | "airport", 205 | "baseball_diamond", 206 | "basketball_court", 207 | "beach", 208 | "bridge", 209 | "chaparral", 210 | "church", 211 | "circular_farmland", 212 | "cloud", 213 | "commercial_area", 214 | "dense_residential", 215 | "desert", 216 | "forest", 217 | "freeway", 218 | "golf_course", 219 | "ground_track_field", 220 | "harbor", 221 | "industrial_area", 222 | "intersection", 223 | "island", 224 | "lake", 225 | "meadow", 226 | "medium_residential", 227 | "mobile_home_park", 228 | "mountain", 229 | "overpass", 230 | "palace", 231 | "parking_lot", 232 | "railway", 233 | "railway_station", 234 | "rectangular_farmland", 235 | "river", 236 | "roundabout", 237 | "runway", 238 | "sea_ice", 239 | "ship", 240 | "snowberg", 241 | "sparse_residential", 242 | "stadium", 243 | "storage_tank", 244 | "tennis_court", 245 | "terrace", 246 | "thermal_power_station", 247 | "wetland", 248 | ] 249 | 250 | def __init__( 251 | self, 252 | root: str = "data", 253 | split: str = "train", 254 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 255 | ) -> None: 256 | """Initialize a new RESISC45 dataset instance. 
257 | Args: 258 | root: root directory where dataset can be found 259 | split: one of "train", "val", or "test" 260 | transforms: a function/transform that takes input sample and its target as 261 | entry and returns a transformed version 262 | """ 263 | assert split in self.splits 264 | self.root = root 265 | 266 | valid_fns = set() 267 | with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f: 268 | for fn in f: 269 | valid_fns.add(fn.strip()) 270 | is_in_split: Callable[[str], bool] = lambda x: os.path.basename( 271 | x) in valid_fns 272 | 273 | super().__init__( 274 | root=os.path.join(root, self.directory), 275 | transforms=transforms, 276 | is_valid_file=is_in_split, 277 | ) 278 | 279 | 280 | 281 | class RESISC45: 282 | def __init__(self, 283 | preprocess, 284 | location=os.path.expanduser('~/data'), 285 | batch_size=32, 286 | num_workers=0): 287 | 288 | self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess) 289 | self.train_loader = torch.utils.data.DataLoader( 290 | self.train_dataset, 291 | shuffle=True, 292 | batch_size=batch_size, 293 | num_workers=num_workers, 294 | ) 295 | 296 | self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess) 297 | self.test_loader = torch.utils.data.DataLoader( 298 | self.test_dataset, 299 | batch_size=batch_size, 300 | num_workers=num_workers 301 | ) 302 | self.test_loader_shuffle = torch.utils.data.DataLoader( 303 | self.test_dataset, 304 | shuffle=True, 305 | batch_size=batch_size, 306 | num_workers=num_workers 307 | ) 308 | 309 | # class names have _ so split on this for better zero-shot head 310 | self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes] 311 | -------------------------------------------------------------------------------- /src/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.datasets as datasets 4 | 5 | class SUN397: 6 | def __init__(self, 7 | preprocess, 8 | location=os.path.expanduser('~/data'), 9 | batch_size=32, 10 | num_workers=0): 11 | # Data loading code 12 | traindir = os.path.join(location, 'sun397', 'train') 13 | valdir = os.path.join(location, 'sun397', 'test') 14 | 15 | 16 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess) 17 | self.train_loader = torch.utils.data.DataLoader( 18 | self.train_dataset, 19 | shuffle=True, 20 | batch_size=batch_size, 21 | num_workers=num_workers, 22 | ) 23 | 24 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess) 25 | self.test_loader = torch.utils.data.DataLoader( 26 | self.test_dataset, 27 | batch_size=batch_size, 28 | num_workers=num_workers 29 | ) 30 | self.test_loader_shuffle = torch.utils.data.DataLoader( 31 | self.test_dataset, 32 | shuffle=True, 33 | batch_size=batch_size, 34 | num_workers=num_workers 35 | ) 36 | idx_to_class = dict((v, k) 37 | for k, v in self.train_dataset.class_to_idx.items()) 38 | self.classnames = [idx_to_class[i][2:].replace('_', ' ') for i in range(len(idx_to_class))] 39 | -------------------------------------------------------------------------------- /src/datasets/svhn.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.datasets import SVHN as PyTorchSVHN 4 | import numpy as np 5 | 6 | 7 | class SVHN: 8 | def __init__(self, 9 | preprocess, 10 | location=os.path.expanduser('~/data'), 11 | batch_size=128, 12 | 
num_workers=0): 13 | 14 | # to fit with repo conventions for location 15 | modified_location = os.path.join(location, 'svhn') 16 | 17 | self.train_dataset = PyTorchSVHN( 18 | root=modified_location, 19 | download=True, 20 | split='train', 21 | transform=preprocess 22 | ) 23 | 24 | self.train_loader = torch.utils.data.DataLoader( 25 | self.train_dataset, 26 | batch_size=batch_size, 27 | shuffle=True, 28 | num_workers=num_workers 29 | ) 30 | 31 | self.test_dataset = PyTorchSVHN( 32 | root=modified_location, 33 | download=True, 34 | split='test', 35 | transform=preprocess 36 | ) 37 | 38 | self.test_loader = torch.utils.data.DataLoader( 39 | self.test_dataset, 40 | batch_size=batch_size, 41 | shuffle=False, 42 | num_workers=num_workers 43 | ) 44 | self.test_loader_shuffle = torch.utils.data.DataLoader( 45 | self.test_dataset, 46 | batch_size=batch_size, 47 | shuffle=True, 48 | num_workers=num_workers 49 | ) 50 | 51 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9'] 52 | -------------------------------------------------------------------------------- /src/datasets/templates.py: -------------------------------------------------------------------------------- 1 | cars_template = [ 2 | lambda c: f'a photo of a {c}.', 3 | lambda c: f'a photo of the {c}.', 4 | lambda c: f'a photo of my {c}.', 5 | lambda c: f'i love my {c}!', 6 | lambda c: f'a photo of my dirty {c}.', 7 | lambda c: f'a photo of my clean {c}.', 8 | lambda c: f'a photo of my new {c}.', 9 | lambda c: f'a photo of my old {c}.', 10 | ] 11 | 12 | cifar10_template = [ 13 | lambda c: f'a photo of a {c}.', 14 | lambda c: f'a blurry photo of a {c}.', 15 | lambda c: f'a black and white photo of a {c}.', 16 | lambda c: f'a low contrast photo of a {c}.', 17 | lambda c: f'a high contrast photo of a {c}.', 18 | lambda c: f'a bad photo of a {c}.', 19 | lambda c: f'a good photo of a {c}.', 20 | lambda c: f'a photo of a small {c}.', 21 | lambda c: f'a photo of a big {c}.', 22 | lambda c: f'a photo of the {c}.', 23 | lambda c: f'a blurry photo of the {c}.', 24 | lambda c: f'a black and white photo of the {c}.', 25 | lambda c: f'a low contrast photo of the {c}.', 26 | lambda c: f'a high contrast photo of the {c}.', 27 | lambda c: f'a bad photo of the {c}.', 28 | lambda c: f'a good photo of the {c}.', 29 | lambda c: f'a photo of the small {c}.', 30 | lambda c: f'a photo of the big {c}.', 31 | ] 32 | 33 | cifar100_template = [ 34 | lambda c: f'a photo of a {c}.', 35 | lambda c: f'a blurry photo of a {c}.', 36 | lambda c: f'a black and white photo of a {c}.', 37 | lambda c: f'a low contrast photo of a {c}.', 38 | lambda c: f'a high contrast photo of a {c}.', 39 | lambda c: f'a bad photo of a {c}.', 40 | lambda c: f'a good photo of a {c}.', 41 | lambda c: f'a photo of a small {c}.', 42 | lambda c: f'a photo of a big {c}.', 43 | lambda c: f'a photo of the {c}.', 44 | lambda c: f'a blurry photo of the {c}.', 45 | lambda c: f'a black and white photo of the {c}.', 46 | lambda c: f'a low contrast photo of the {c}.', 47 | lambda c: f'a high contrast photo of the {c}.', 48 | lambda c: f'a bad photo of the {c}.', 49 | lambda c: f'a good photo of the {c}.', 50 | lambda c: f'a photo of the small {c}.', 51 | lambda c: f'a photo of the big {c}.', 52 | ] 53 | 54 | dtd_template = [ 55 | lambda c: f'a photo of a {c} texture.', 56 | lambda c: f'a photo of a {c} pattern.', 57 | lambda c: f'a photo of a {c} thing.', 58 | lambda c: f'a photo of a {c} object.', 59 | lambda c: f'a photo of the {c} texture.', 60 | lambda c: f'a photo of the {c} pattern.', 61 
| lambda c: f'a photo of the {c} thing.', 62 | lambda c: f'a photo of the {c} object.', 63 | ] 64 | 65 | eurosat_template = [ 66 | lambda c: f'a centered satellite photo of {c}.', 67 | lambda c: f'a centered satellite photo of a {c}.', 68 | lambda c: f'a centered satellite photo of the {c}.', 69 | ] 70 | 71 | food101_template = [ 72 | lambda c: f'a photo of {c}, a type of food.', 73 | ] 74 | 75 | gtsrb_template = [ 76 | lambda c: f'a zoomed in photo of a "{c}" traffic sign.', 77 | lambda c: f'a centered photo of a "{c}" traffic sign.', 78 | lambda c: f'a close up photo of a "{c}" traffic sign.', 79 | ] 80 | 81 | mnist_template = [ 82 | lambda c: f'a photo of the number: "{c}".', 83 | ] 84 | 85 | imagenet_template = [ 86 | lambda c: f'a bad photo of a {c}.', 87 | lambda c: f'a photo of many {c}.', 88 | lambda c: f'a sculpture of a {c}.', 89 | lambda c: f'a photo of the hard to see {c}.', 90 | lambda c: f'a low resolution photo of the {c}.', 91 | lambda c: f'a rendering of a {c}.', 92 | lambda c: f'graffiti of a {c}.', 93 | lambda c: f'a bad photo of the {c}.', 94 | lambda c: f'a cropped photo of the {c}.', 95 | lambda c: f'a tattoo of a {c}.', 96 | lambda c: f'the embroidered {c}.', 97 | lambda c: f'a photo of a hard to see {c}.', 98 | lambda c: f'a bright photo of a {c}.', 99 | lambda c: f'a photo of a clean {c}.', 100 | lambda c: f'a photo of a dirty {c}.', 101 | lambda c: f'a dark photo of the {c}.', 102 | lambda c: f'a drawing of a {c}.', 103 | lambda c: f'a photo of my {c}.', 104 | lambda c: f'the plastic {c}.', 105 | lambda c: f'a photo of the cool {c}.', 106 | lambda c: f'a close-up photo of a {c}.', 107 | lambda c: f'a black and white photo of the {c}.', 108 | lambda c: f'a painting of the {c}.', 109 | lambda c: f'a painting of a {c}.', 110 | lambda c: f'a pixelated photo of the {c}.', 111 | lambda c: f'a sculpture of the {c}.', 112 | lambda c: f'a bright photo of the {c}.', 113 | lambda c: f'a cropped photo of a {c}.', 114 | lambda c: f'a plastic {c}.', 115 | lambda c: f'a photo of the dirty {c}.', 116 | lambda c: f'a jpeg corrupted photo of a {c}.', 117 | lambda c: f'a blurry photo of the {c}.', 118 | lambda c: f'a photo of the {c}.', 119 | lambda c: f'a good photo of the {c}.', 120 | lambda c: f'a rendering of the {c}.', 121 | lambda c: f'a {c} in a video game.', 122 | lambda c: f'a photo of one {c}.', 123 | lambda c: f'a doodle of a {c}.', 124 | lambda c: f'a close-up photo of the {c}.', 125 | lambda c: f'a photo of a {c}.', 126 | lambda c: f'the origami {c}.', 127 | lambda c: f'the {c} in a video game.', 128 | lambda c: f'a sketch of a {c}.', 129 | lambda c: f'a doodle of the {c}.', 130 | lambda c: f'a origami {c}.', 131 | lambda c: f'a low resolution photo of a {c}.', 132 | lambda c: f'the toy {c}.', 133 | lambda c: f'a rendition of the {c}.', 134 | lambda c: f'a photo of the clean {c}.', 135 | lambda c: f'a photo of a large {c}.', 136 | lambda c: f'a rendition of a {c}.', 137 | lambda c: f'a photo of a nice {c}.', 138 | lambda c: f'a photo of a weird {c}.', 139 | lambda c: f'a blurry photo of a {c}.', 140 | lambda c: f'a cartoon {c}.', 141 | lambda c: f'art of a {c}.', 142 | lambda c: f'a sketch of the {c}.', 143 | lambda c: f'a embroidered {c}.', 144 | lambda c: f'a pixelated photo of a {c}.', 145 | lambda c: f'itap of the {c}.', 146 | lambda c: f'a jpeg corrupted photo of the {c}.', 147 | lambda c: f'a good photo of a {c}.', 148 | lambda c: f'a plushie {c}.', 149 | lambda c: f'a photo of the nice {c}.', 150 | lambda c: f'a photo of the small {c}.', 151 | lambda c: f'a photo 
of the weird {c}.', 152 | lambda c: f'the cartoon {c}.', 153 | lambda c: f'art of the {c}.', 154 | lambda c: f'a drawing of the {c}.', 155 | lambda c: f'a photo of the large {c}.', 156 | lambda c: f'a black and white photo of a {c}.', 157 | lambda c: f'the plushie {c}.', 158 | lambda c: f'a dark photo of a {c}.', 159 | lambda c: f'itap of a {c}.', 160 | lambda c: f'graffiti of the {c}.', 161 | lambda c: f'a toy {c}.', 162 | lambda c: f'itap of my {c}.', 163 | lambda c: f'a photo of a cool {c}.', 164 | lambda c: f'a photo of a small {c}.', 165 | lambda c: f'a tattoo of the {c}.', 166 | ] 167 | 168 | resisc45_template = [ 169 | lambda c: f'satellite imagery of {c}.', 170 | lambda c: f'aerial imagery of {c}.', 171 | lambda c: f'satellite photo of {c}.', 172 | lambda c: f'aerial photo of {c}.', 173 | lambda c: f'satellite view of {c}.', 174 | lambda c: f'aerial view of {c}.', 175 | lambda c: f'satellite imagery of a {c}.', 176 | lambda c: f'aerial imagery of a {c}.', 177 | lambda c: f'satellite photo of a {c}.', 178 | lambda c: f'aerial photo of a {c}.', 179 | lambda c: f'satellite view of a {c}.', 180 | lambda c: f'aerial view of a {c}.', 181 | lambda c: f'satellite imagery of the {c}.', 182 | lambda c: f'aerial imagery of the {c}.', 183 | lambda c: f'satellite photo of the {c}.', 184 | lambda c: f'aerial photo of the {c}.', 185 | lambda c: f'satellite view of the {c}.', 186 | lambda c: f'aerial view of the {c}.', 187 | ] 188 | 189 | stl10_template = [ 190 | lambda c: f'a photo of a {c}.', 191 | lambda c: f'a photo of the {c}.', 192 | ] 193 | 194 | sun397_template = [ 195 | lambda c: f'a photo of a {c}.', 196 | lambda c: f'a photo of the {c}.', 197 | ] 198 | 199 | svhn_template = [ 200 | lambda c: f'a photo of the number: "{c}".', 201 | ] 202 | 203 | 204 | dataset_to_template = { 205 | 'Cars': cars_template, 206 | 'CIFAR10': cifar10_template, 207 | 'CIFAR100': cifar100_template, 208 | 'DTD': dtd_template, 209 | 'EuroSAT': eurosat_template, 210 | 'Food101': food101_template, 211 | 'GTSRB': gtsrb_template, 212 | 'MNIST': mnist_template, 213 | 'ImageNet': imagenet_template, 214 | 'RESISC45': resisc45_template, 215 | 'STL10': stl10_template, 216 | 'SUN397': sun397_template, 217 | 'SVHN': svhn_template, 218 | } 219 | 220 | 221 | def get_templates(dataset_name): 222 | if dataset_name.endswith('Val'): 223 | return get_templates(dataset_name.replace('Val', '')) 224 | assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}' 225 | return dataset_to_template[dataset_name] -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import tqdm 4 | 5 | import torch 6 | import numpy as np 7 | 8 | import utils 9 | from datasets.common import get_dataloader, maybe_dictionarize 10 | from heads import get_classification_head 11 | from modeling import ImageClassifier 12 | 13 | from datasets.registry import get_dataset 14 | 15 | def eval_single_dataset(image_encoder, dataset_name, args): 16 | classification_head = get_classification_head(args, dataset_name) 17 | model = ImageClassifier(image_encoder, classification_head) 18 | 19 | model.eval() 20 | 21 | dataset = get_dataset( 22 | dataset_name, 23 | model.val_preprocess, 24 | location=args.data_location, 25 | batch_size=args.batch_size 26 | ) 27 | dataloader = get_dataloader( 28 | dataset, is_train=False, args=args, image_encoder=None) 29 | device = args.device 30 | 31 | with 
torch.no_grad(): 32 | top1, correct, n = 0., 0., 0. 33 | for i, data in enumerate(tqdm.tqdm(dataloader)): 34 | data = maybe_dictionarize(data) 35 | x = data['images'].to(device) 36 | y = data['labels'].to(device) 37 | 38 | logits = utils.get_logits(x, model) 39 | 40 | pred = logits.argmax(dim=1, keepdim=True).to(device) 41 | 42 | correct += pred.eq(y.view_as(pred)).sum().item() 43 | 44 | n += y.size(0) 45 | 46 | top1 = correct / n 47 | 48 | metrics = {'top1': top1} 49 | print(f'Done evaluating on {dataset_name}. Accuracy: {100*top1:.2f}%') 50 | 51 | return metrics 52 | 53 | def eval_single_dataset_head(image_encoder, head, dataset_name, args): 54 | model = ImageClassifier(image_encoder, head) 55 | 56 | model.eval() 57 | 58 | dataset = get_dataset(dataset_name, model.val_preprocess, location=args.data_location, batch_size=args.batch_size) 59 | dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None) 60 | device = args.device 61 | 62 | with torch.no_grad(): 63 | top1, correct, n = 0., 0., 0. 64 | for i, data in enumerate(tqdm.tqdm(dataloader)): 65 | data = maybe_dictionarize(data) 66 | x = data['images'].to(device) 67 | y = data['labels'].to(device) 68 | 69 | logits = utils.get_logits(x, model) 70 | 71 | pred = logits.argmax(dim=1, keepdim=True).to(device) 72 | 73 | correct += pred.eq(y.view_as(pred)).sum().item() 74 | 75 | n += y.size(0) 76 | 77 | top1 = correct / n 78 | 79 | metrics = {'top1': top1} 80 | print(f'Done evaluating on {dataset_name}. Accuracy: {100 * top1:.2f}%') 81 | 82 | return metrics 83 | 84 | def eval_single_dataset_preprocess_head(image_encoder, head, dataset_name, args): 85 | model = ImageClassifier(image_encoder, head) 86 | 87 | model.eval() 88 | 89 | dataset = get_dataset(dataset_name, model.val_preprocess, location=args.data_location, batch_size=args.batch_size) 90 | dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None) 91 | device = args.device 92 | 93 | with torch.no_grad(): 94 | top1, correct, n = 0., 0., 0. 95 | for i, data in enumerate(tqdm.tqdm(dataloader)): 96 | data = maybe_dictionarize(data) 97 | x = data['images'].to(device) 98 | y = data['labels'].to(device) 99 | 100 | logits = utils.get_logits(x, model) 101 | 102 | pred = logits.argmax(dim=1, keepdim=True).to(device) 103 | 104 | correct += pred.eq(y.view_as(pred)).sum().item() 105 | 106 | n += y.size(0) 107 | 108 | top1 = correct / n 109 | 110 | metrics = {'top1': top1} 111 | print(f'Done evaluating on {dataset_name}. 
Accuracy: {100 * top1:.2f}%') 112 | 113 | return metrics 114 | 115 | def evaluate(image_encoder, args): 116 | if args.eval_datasets is None: 117 | return 118 | info = vars(args) 119 | for i, dataset_name in enumerate(args.eval_datasets): 120 | print('Evaluating on', dataset_name) 121 | 122 | results = eval_single_dataset(image_encoder, dataset_name, args) 123 | 124 | if 'top1' in results: 125 | print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}") 126 | for key, val in results.items(): 127 | if 'worst' in key or 'f1' in key.lower() or 'pm0' in key: 128 | print(f"{dataset_name} {key}: {val:.4f}") 129 | info[dataset_name + ':' + key] = val 130 | 131 | if args.results_db is not None: 132 | dirname = os.path.dirname(args.results_db) 133 | if dirname: 134 | os.makedirs(dirname, exist_ok=True) 135 | with open(args.results_db, 'a+') as f: 136 | f.write(json.dumps(info) + '\n') 137 | print(f'Results saved to {args.results_db}.') 138 | else: 139 | print('Results not saved (to do so, use --results_db to specify a path).') 140 | 141 | return info -------------------------------------------------------------------------------- /src/heads.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | 5 | import open_clip 6 | 7 | from datasets.templates import get_templates 8 | from datasets.registry import get_dataset 9 | 10 | from modeling import ClassificationHead, ImageEncoder 11 | 12 | 13 | def build_classification_head(model, dataset_name, template, data_location, device): 14 | template = get_templates(dataset_name) 15 | 16 | logit_scale = model.logit_scale 17 | dataset = get_dataset( 18 | dataset_name, 19 | None, 20 | location=data_location 21 | ) 22 | model.eval() 23 | model.to(device) 24 | 25 | print('Building classification head.') 26 | with torch.no_grad(): 27 | zeroshot_weights = [] 28 | for classname in tqdm(dataset.classnames): 29 | texts = [] 30 | for t in template: 31 | texts.append(t(classname)) 32 | texts = open_clip.tokenize(texts).to(device) # tokenize 33 | embeddings = model.encode_text(texts) # embed with text encoder 34 | embeddings /= embeddings.norm(dim=-1, keepdim=True) 35 | 36 | embeddings = embeddings.mean(dim=0, keepdim=True) 37 | embeddings /= embeddings.norm() 38 | 39 | zeroshot_weights.append(embeddings) 40 | 41 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device) 42 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2) 43 | 44 | zeroshot_weights *= logit_scale.exp() 45 | 46 | zeroshot_weights = zeroshot_weights.squeeze().float() 47 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1) 48 | 49 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights) 50 | 51 | return classification_head 52 | 53 | 54 | def get_classification_head(args, dataset): 55 | filename = os.path.join(args.save, f'head_{dataset}.pt') 56 | if os.path.exists(filename): 57 | print(f'Classification head for {args.model} on {dataset} exists at {filename}') 58 | return ClassificationHead.load(filename) 59 | print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.') 60 | model = ImageEncoder(args, keep_lang=True).model 61 | template = get_templates(dataset) 62 | classification_head = build_classification_head(model, dataset, template, args.data_location, args.device) 63 | os.makedirs(args.save, exist_ok=True) 64 | classification_head.save(filename) 65 | return classification_head 66 | 67 | 
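# A minimal sketch of the per-class computation that build_classification_head
# above performs, assuming `model` is an open_clip CLIP model with its language
# tower kept (keep_lang=True) and `template` is a list of prompt callables as in
# datasets/templates.py; the helper name is illustrative, not part of the repo:
def _zeroshot_weight_sketch(model, classname, template, device='cpu'):
    import torch
    import open_clip
    with torch.no_grad():
        texts = open_clip.tokenize([t(classname) for t in template]).to(device)
        embeddings = model.encode_text(texts)                            # (num_templates, dim)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # one unit vector per prompt
        mean_embedding = embeddings.mean(dim=0)                          # average over templates
        return mean_embedding / mean_embedding.norm()                    # renormalized class weight row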
-------------------------------------------------------------------------------- /src/main_layer_wise_adamerging.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 3 | 4 | import time 5 | import sys 6 | import tqdm 7 | sys.path.append('/home/taskarithmetic/') 8 | 9 | import torch 10 | from task_vectors import TaskVector 11 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head 12 | from args import parse_arguments 13 | def create_log_dir(path, filename='log.txt'): 14 | import logging 15 | if not os.path.exists(path): 16 | os.makedirs(path) 17 | logger = logging.getLogger(path) 18 | logger.setLevel(logging.DEBUG) 19 | fh = logging.FileHandler(path+'/'+filename) 20 | fh.setLevel(logging.DEBUG) 21 | ch = logging.StreamHandler() 22 | ch.setLevel(logging.DEBUG) 23 | logger.addHandler(fh) 24 | logger.addHandler(ch) 25 | return logger 26 | 27 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD 28 | model = 'ViT-B-32' 29 | args = parse_arguments() 30 | args.data_location = '/home/taskarithmetic/data' 31 | args.model = model 32 | args.save = '/home/taskarithmetic/checkpoints/' + model 33 | args.logs_path = '/home/taskarithmetic/logs/' + model 34 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt' 35 | 36 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 37 | log = create_log_dir(args.logs_path, 'log_{}_Layer_wise_AdaMerging.txt'.format(str_time_)) 38 | args.log = log 39 | 40 | task_vectors = [TaskVector(pretrained_checkpoint, '/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets] 41 | 42 | def del_attr(obj, names): 43 | if len(names) == 1: 44 | delattr(obj, names[0]) 45 | else: 46 | del_attr(getattr(obj, names[0]), names[1:]) 47 | 48 | def set_attr(obj, names, val): 49 | if len(names) == 1: 50 | setattr(obj, names[0], val) 51 | else: 52 | set_attr(getattr(obj, names[0]), names[1:], val) 53 | 54 | def make_functional(mod): 55 | orig_params = tuple(mod.parameters()) 56 | names = [] 57 | for name, p in list(mod.named_parameters()): 58 | del_attr(mod, name.split(".")) 59 | names.append(name) 60 | return orig_params, names 61 | 62 | def load_weights(mod, names, params): 63 | for name, p in zip(names, params): 64 | set_attr(mod, name.split("."), p) 65 | 66 | class ModelWrapper(torch.nn.Module): 67 | def __init__(self, model, initial_weights=None): 68 | super(ModelWrapper, self).__init__() 69 | self.model = model 70 | 71 | if hasattr(self.model, 'transformer'): 72 | delattr(self.model, 'transformer') 73 | 74 | def forward(self, images): 75 | features = self.model(images) 76 | return features 77 | 78 | from heads import get_classification_head 79 | class AdaMerging(torch.nn.Module): 80 | def __init__(self, paramslist, model, names, exam_datasets): 81 | super(AdaMerging, self).__init__() 82 | self.paramslist = paramslist 83 | self.model = model 84 | self.names = names 85 | self.pretrain_lambdas = torch.ones(len(paramslist[0]), 1) 86 | prior = 0.3 87 | rlambdas = torch.ones(len(paramslist[0]), len(paramslist)-1) * prior # (1 * tasks) 88 | self.lambdas_raw = torch.nn.Parameter(rlambdas) 89 | 90 | self.classifier = [] 91 | for dataset_name in exam_datasets: 92 | classification_head = get_classification_head(args, dataset_name) 93 | layer_name = 
'classifier_{}'.format(dataset_name) 94 | self.add_module(layer_name, classification_head.to(args.device)) 95 | self.classifier.append(layer_name) 96 | 97 | def lambdas(self): 98 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0) 99 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1) 100 | return lambdass 101 | 102 | def collect_trainable_params(self): 103 | return [self.lambdas_raw] 104 | 105 | def get_classification_head(self, dataset_name): 106 | layer_name = 'classifier_{}'.format(dataset_name) 107 | classification_head = getattr(self, layer_name) 108 | return classification_head 109 | 110 | def get_image_encoder(self): 111 | alph = self.lambdas() 112 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 113 | params = tuple(p.cuda(0) for p in params) 114 | load_weights(self.model, self.names, params) 115 | return self.model 116 | 117 | def forward(self, inp, dataset_name): 118 | alph = self.lambdas() 119 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 120 | 121 | params = tuple(p.cuda(0) for p in params) 122 | load_weights(self.model, self.names, params) 123 | feature = self.model(inp) 124 | 125 | layer_name = 'classifier_{}'.format(dataset_name) 126 | classification_head = getattr(self, layer_name) 127 | out = classification_head(feature) 128 | return out 129 | 130 | def softmax_entropy(x): 131 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 132 | 133 | pretrained_model = torch.load(pretrained_checkpoint) 134 | pretrained_model_dic = pretrained_model.state_dict() 135 | 136 | model = ModelWrapper(pretrained_model, exam_datasets) 137 | model = model.to(args.device) 138 | _, names = make_functional(model) 139 | 140 | paramslist = [] 141 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain 142 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.vector.items()) for i, tv in enumerate(task_vectors)] # task vectors 143 | torch.cuda.empty_cache() 144 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets) 145 | 146 | print('init lambda:') 147 | print(adamerging_mtl_model.lambdas()) 148 | print('collect_trainable_params:') 149 | print(list(adamerging_mtl_model.collect_trainable_params())) 150 | 151 | epochs = 500 152 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.) 153 | 154 | from datasets.registry import get_dataset 155 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle 156 | 157 | Total_ACC = 0. 158 | for dataset_name in exam_datasets: 159 | image_encoder = adamerging_mtl_model.get_image_encoder() 160 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 161 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 162 | Total_ACC += metrics['top1'] 163 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 164 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 165 | 166 | for epoch in range(epochs): 167 | losses = 0. 
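    # Unsupervised adaptation step: the loop below draws two shuffled *test*
    # batches per task (the `if i > 0: break` cuts the loader short) and
    # accumulates the mean Shannon entropy of the merged model's predictions,
    # H(p) = -sum_c p_c * log p_c, via softmax_entropy above. Minimizing this
    # w.r.t. lambdas_raw needs no labels; `y` is loaded but never enters the loss.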
168 | for dataset_name in exam_datasets: 169 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16) 170 | dataloader = get_dataloader_shuffle(dataset) 171 | 172 | for i, data in enumerate(tqdm.tqdm(dataloader)): 173 | data = maybe_dictionarize(data) 174 | x = data['images'].to(args.device) 175 | y = data['labels'].to(args.device) 176 | 177 | outputs = adamerging_mtl_model(x, dataset_name) 178 | loss = softmax_entropy(outputs).mean(0) 179 | losses += loss 180 | 181 | if i > 0: 182 | break 183 | 184 | optimizer.zero_grad() 185 | losses.backward() 186 | optimizer.step() 187 | 188 | print(list(adamerging_mtl_model.lambdas().data)) 189 | 190 | if ((epoch+1) % 500) == 0: 191 | log.info(str(list(adamerging_mtl_model.lambdas().data))) 192 | 193 | Total_ACC = 0. 194 | for dataset_name in exam_datasets: 195 | image_encoder = adamerging_mtl_model.get_image_encoder() 196 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 197 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 198 | Total_ACC += metrics['top1'] 199 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 200 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 201 | -------------------------------------------------------------------------------- /src/main_layer_wise_adamergingpp.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 3 | 4 | import time 5 | import sys 6 | import tqdm 7 | sys.path.append('/home/taskarithmetic/') 8 | 9 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head 10 | from args import parse_arguments 11 | 12 | def create_log_dir(path, filename='log.txt'): 13 | import logging 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | logger = logging.getLogger(path) 17 | logger.setLevel(logging.DEBUG) 18 | fh = logging.FileHandler(path+'/'+filename) 19 | fh.setLevel(logging.DEBUG) 20 | ch = logging.StreamHandler() 21 | ch.setLevel(logging.DEBUG) 22 | logger.addHandler(fh) 23 | logger.addHandler(ch) 24 | return logger 25 | 26 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD 27 | model = 'ViT-B-32' 28 | args = parse_arguments() 29 | args.data_location = '/home/taskarithmetic/data' 30 | args.model = model 31 | args.save = '/home/taskarithmetic/checkpoints/' + model 32 | args.logs_path = '/home/taskarithmetic/logs/' + model 33 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt' 34 | 35 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 36 | log = create_log_dir(args.logs_path, 'log_{}_Layer_wise_AdaMergingPP.txt'.format(str_time_)) 37 | args.log = log 38 | 39 | from ties_merging_utils import * 40 | 41 | ft_checks = [torch.load('/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt').state_dict() for dataset_name in exam_datasets] 42 | ptm_check = torch.load(pretrained_checkpoint).state_dict() 43 | 44 | check_parameterNamesMatch(ft_checks + [ptm_check]) 45 | 46 | remove_keys = [] 47 | print(f"Flattening out Checkpoints") 48 | flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks]) 49 | flat_ptm = state_dict_to_vector(ptm_check, 
remove_keys) 50 | 51 | tv_flat_checks = flat_ft - flat_ptm 52 | 53 | assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check) 54 | assert all([check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_checks[i])for i in range(len(ft_checks))]) 55 | 56 | 57 | K = 20 58 | merge_func = "dis-sum" 59 | 60 | selected_entries, merged_tv = ties_merging_split(tv_flat_checks, reset_thresh=K, merge_func=merge_func,) 61 | 62 | ties_task_vectors = [] 63 | for vector_ in selected_entries: 64 | t_state_dict = vector_to_state_dict(vector_, ptm_check, remove_keys=remove_keys) 65 | ref_model = torch.load(pretrained_checkpoint) 66 | ref_model.load_state_dict(t_state_dict, strict=False) 67 | ties_task_vectors.append(ref_model.state_dict()) 68 | 69 | def del_attr(obj, names): 70 | if len(names) == 1: 71 | delattr(obj, names[0]) 72 | else: 73 | del_attr(getattr(obj, names[0]), names[1:]) 74 | 75 | def set_attr(obj, names, val): 76 | if len(names) == 1: 77 | setattr(obj, names[0], val) 78 | else: 79 | set_attr(getattr(obj, names[0]), names[1:], val) 80 | 81 | def make_functional(mod): 82 | orig_params = tuple(mod.parameters()) 83 | names = [] 84 | for name, p in list(mod.named_parameters()): 85 | del_attr(mod, name.split(".")) 86 | names.append(name) 87 | return orig_params, names 88 | 89 | def load_weights(mod, names, params): 90 | for name, p in zip(names, params): 91 | set_attr(mod, name.split("."), p) 92 | 93 | 94 | class ModelWrapper(torch.nn.Module): 95 | def __init__(self, model, initial_weights=None): 96 | super(ModelWrapper, self).__init__() 97 | self.model = model 98 | 99 | if hasattr(self.model, 'transformer'): 100 | delattr(self.model, 'transformer') 101 | 102 | def forward(self, images): 103 | features = self.model(images) 104 | return features 105 | 106 | from heads import get_classification_head 107 | class AdaMerging(torch.nn.Module): 108 | def __init__(self, paramslist, model, names, exam_datasets): 109 | super(AdaMerging, self).__init__() 110 | self.paramslist = paramslist 111 | self.model = model 112 | self.names = names 113 | self.pretrain_lambdas = torch.ones(len(paramslist[0]), 1) 114 | prior = 0.3 115 | rlambdas = torch.ones(len(paramslist[0]), len(paramslist)-1) * prior # (1 * tasks) 116 | self.lambdas_raw = torch.nn.Parameter(rlambdas) 117 | 118 | self.classifier = [] 119 | for dataset_name in exam_datasets: 120 | classification_head = get_classification_head(args, dataset_name) 121 | layer_name = 'classifier_{}'.format(dataset_name) 122 | self.add_module(layer_name, classification_head.to(args.device)) 123 | self.classifier.append(layer_name) 124 | 125 | def lambdas(self): 126 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0) 127 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1) 128 | return lambdass 129 | 130 | def collect_trainable_params(self): 131 | return [self.lambdas_raw] 132 | 133 | def get_classification_head(self, dataset_name): 134 | layer_name = 'classifier_{}'.format(dataset_name) 135 | classification_head = getattr(self, layer_name) 136 | return classification_head 137 | 138 | def get_image_encoder(self): 139 | alph = self.lambdas() 140 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 141 | params = tuple(p.cuda(0) for p in params) 142 | load_weights(self.model, self.names, params) 143 | return self.model 144 | 145 | def forward(self, inp, dataset_name): 146 | alph = self.lambdas() 147 | params = 
tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 148 | 149 | params = tuple(p.cuda(0) for p in params) 150 | load_weights(self.model, self.names, params) 151 | feature = self.model(inp) 152 | 153 | layer_name = 'classifier_{}'.format(dataset_name) 154 | classification_head = getattr(self, layer_name) 155 | out = classification_head(feature) 156 | 157 | return out 158 | 159 | def softmax_entropy(x): 160 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 161 | 162 | pretrained_model = torch.load(pretrained_checkpoint) 163 | pretrained_model_dic = pretrained_model.state_dict() 164 | 165 | model = ModelWrapper(pretrained_model, exam_datasets) 166 | model = model.to(args.device) 167 | _, names = make_functional(model) 168 | 169 | paramslist = [] 170 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain 171 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.items()) for i, tv in enumerate(ties_task_vectors)] # task vectors 172 | torch.cuda.empty_cache() 173 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets) 174 | 175 | print('init lambda:') 176 | print(adamerging_mtl_model.lambdas()) 177 | print('collect_trainable_params:') 178 | print(list(adamerging_mtl_model.collect_trainable_params())) 179 | 180 | epochs = 500 181 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.) 182 | 183 | from datasets.registry import get_dataset 184 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle 185 | 186 | Total_ACC = 0. 187 | for dataset_name in exam_datasets: 188 | image_encoder = adamerging_mtl_model.get_image_encoder() 189 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 190 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 191 | Total_ACC += metrics['top1'] 192 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 193 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 194 | 195 | for epoch in range(epochs): 196 | losses = 0. 197 | for dataset_name in exam_datasets: 198 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16) 199 | dataloader = get_dataloader_shuffle(dataset) 200 | 201 | for i, data in enumerate(tqdm.tqdm(dataloader)): 202 | data = maybe_dictionarize(data) 203 | x = data['images'].to(args.device) 204 | y = data['labels'].to(args.device) 205 | 206 | outputs = adamerging_mtl_model(x, dataset_name) 207 | loss = softmax_entropy(outputs).mean(0) 208 | losses += loss 209 | 210 | if i > 0: 211 | break 212 | 213 | optimizer.zero_grad() 214 | losses.backward() 215 | optimizer.step() 216 | 217 | print(list(adamerging_mtl_model.lambdas().data)) 218 | 219 | if ((epoch+1) % 500) == 0: 220 | log.info(str(list(adamerging_mtl_model.lambdas().data))) 221 | 222 | Total_ACC = 0. 
223 | for dataset_name in exam_datasets: 224 | image_encoder = adamerging_mtl_model.get_image_encoder() 225 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 226 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 227 | Total_ACC += metrics['top1'] 228 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 229 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 230 | -------------------------------------------------------------------------------- /src/main_task_arithmetic.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 6 | 7 | import time 8 | import sys 9 | sys.path.append('/home/taskarithmetic/') 10 | 11 | from task_vectors import TaskVector 12 | from eval import eval_single_dataset 13 | from args import parse_arguments 14 | 15 | def create_log_dir(path, filename='log.txt'): 16 | import logging 17 | if not os.path.exists(path): 18 | os.makedirs(path) 19 | logger = logging.getLogger(path) 20 | logger.setLevel(logging.DEBUG) 21 | fh = logging.FileHandler(path+'/'+filename) 22 | fh.setLevel(logging.DEBUG) 23 | ch = logging.StreamHandler() 24 | ch.setLevel(logging.DEBUG) 25 | logger.addHandler(fh) 26 | logger.addHandler(ch) 27 | return logger 28 | 29 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD 30 | model = 'ViT-B-32' 31 | args = parse_arguments() 32 | args.data_location = '/home/taskarithmetic/data' 33 | args.model = model 34 | args.save = '/home/taskarithmetic/checkpoints/' + model 35 | args.logs_path = '/home/taskarithmetic/logs/' + model 36 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt' 37 | 38 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 39 | log = create_log_dir(args.logs_path, 'log_{}_task_arithmetic.txt'.format(str_time_)) 40 | 41 | task_vectors = [ 42 | TaskVector(pretrained_checkpoint, '/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets 43 | ] 44 | 45 | task_vector_sum = sum(task_vectors) 46 | 47 | scaling_coef_ = 0.3 48 | 49 | image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef_) 50 | log.info('*'*20 + 'scaling_coef:' + str(scaling_coef_) + '*'*20) 51 | 52 | accs = [] 53 | for dataset in exam_datasets: 54 | metrics = eval_single_dataset(image_encoder, dataset, args) 55 | log.info(str(dataset) + ':' + str(metrics.get('top1')*100)+'%') 56 | accs.append(metrics.get('top1')*100) 57 | log.info('Avg ACC:' + str(np.mean(accs)) + '%') 58 | -------------------------------------------------------------------------------- /src/main_task_wise_adamerging.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | 4 | import time 5 | import sys 6 | import tqdm 7 | sys.path.append('/root/taskarithmetic') 8 | 9 | import torch 10 | from task_vectors import TaskVector 11 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head 12 | from args import parse_arguments 13 | 14 | def create_log_dir(path, filename='log.txt'): 15 | import logging 16 | if not os.path.exists(path): 17 | 
os.makedirs(path) 18 | logger = logging.getLogger(path) 19 | logger.setLevel(logging.DEBUG) 20 | fh = logging.FileHandler(path+'/'+filename) 21 | fh.setLevel(logging.DEBUG) 22 | ch = logging.StreamHandler() 23 | ch.setLevel(logging.DEBUG) 24 | logger.addHandler(fh) 25 | logger.addHandler(ch) 26 | return logger 27 | 28 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD 29 | model = 'ViT-B-32' 30 | source_root_path = '/root' 31 | args = parse_arguments() 32 | args.data_location = source_root_path+'/dataset' 33 | args.model = model 34 | args.save = source_root_path+'/checkpoint/' + model 35 | args.logs_path = '/root/taskarithmetic/src/logs/' + model 36 | pretrained_checkpoint = source_root_path+'/checkpoint/'+model+'/zeroshot.pt' 37 | 38 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 39 | log = create_log_dir(args.logs_path, 'log_{}_Task_wise_AdaMerging.txt'.format(str_time_)) 40 | args.log = log 41 | 42 | task_vectors = [TaskVector(pretrained_checkpoint, source_root_path+'/checkpoint/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets] 43 | 44 | def del_attr(obj, names): 45 | if len(names) == 1: 46 | delattr(obj, names[0]) 47 | else: 48 | del_attr(getattr(obj, names[0]), names[1:]) 49 | 50 | def set_attr(obj, names, val): 51 | if len(names) == 1: 52 | setattr(obj, names[0], val) 53 | else: 54 | set_attr(getattr(obj, names[0]), names[1:], val) 55 | 56 | def make_functional(mod): 57 | orig_params = tuple(mod.parameters()) 58 | names = [] 59 | for name, p in list(mod.named_parameters()): 60 | del_attr(mod, name.split(".")) 61 | names.append(name) 62 | return orig_params, names 63 | 64 | def load_weights(mod, names, params): 65 | for name, p in zip(names, params): 66 | set_attr(mod, name.split("."), p) 67 | 68 | 69 | class ModelWrapper(torch.nn.Module): 70 | def __init__(self, model, initial_weights=None): 71 | super(ModelWrapper, self).__init__() 72 | self.model = model 73 | 74 | if hasattr(self.model, 'transformer'): 75 | delattr(self.model, 'transformer') 76 | 77 | def forward(self, images): 78 | features = self.model(images) 79 | return features 80 | 81 | from heads import get_classification_head 82 | class AdaMerging(torch.nn.Module): 83 | def __init__(self, paramslist, model, names, exam_datasets): 84 | super(AdaMerging, self).__init__() 85 | self.paramslist = paramslist 86 | self.model = model 87 | self.names = names 88 | self.pretrain_lambdas = torch.ones(1, 1) 89 | prior = 0.3 90 | rlambdas = torch.ones(1, len(paramslist)-1) * prior # (1 * tasks) 91 | self.lambdas_raw = torch.nn.Parameter(rlambdas) 92 | 93 | self.classifier = [] 94 | for dataset_name in exam_datasets: 95 | classification_head = get_classification_head(args, dataset_name) 96 | layer_name = 'classifier_{}'.format(dataset_name) 97 | self.add_module(layer_name, classification_head.to(args.device)) 98 | self.classifier.append(layer_name) 99 | 100 | def lambdas(self): 101 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0) 102 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1) 103 | return lambdass 104 | 105 | def collect_trainable_params(self): 106 | return [self.lambdas_raw] 107 | 108 | def get_classification_head(self, dataset_name): 109 | layer_name = 'classifier_{}'.format(dataset_name) 110 | classification_head = getattr(self, layer_name) 111 | return classification_head 112 | 113 | def get_image_encoder(self): 114 | alph = 
self.lambdas() 115 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 116 | params = tuple(p.cuda(0) for p in params) 117 | load_weights(self.model, self.names, params) 118 | return self.model 119 | 120 | def forward(self, inp, dataset_name): 121 | alph = self.lambdas() 122 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 123 | 124 | params = tuple(p.cuda(0) for p in params) 125 | load_weights(self.model, self.names, params) 126 | feature = self.model(inp) 127 | 128 | layer_name = 'classifier_{}'.format(dataset_name) 129 | classification_head = getattr(self, layer_name) 130 | out = classification_head(feature) 131 | 132 | return out 133 | 134 | def softmax_entropy(x): 135 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 136 | 137 | pretrained_model = torch.load(pretrained_checkpoint) 138 | pretrained_model_dic = pretrained_model.state_dict() 139 | 140 | model = ModelWrapper(pretrained_model, exam_datasets) 141 | model = model.to(args.device) 142 | _, names = make_functional(model) 143 | 144 | paramslist = [] 145 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain 146 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.vector.items()) for i, tv in enumerate(task_vectors)] # task vectors 147 | torch.cuda.empty_cache() 148 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets) 149 | 150 | print('init lambda:') 151 | print(adamerging_mtl_model.lambdas()) 152 | print('collect_trainable_params:') 153 | print(list(adamerging_mtl_model.collect_trainable_params())) 154 | 155 | epochs = 500 156 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.) 157 | 158 | from datasets.registry import get_dataset 159 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle 160 | 161 | Total_ACC = 0. 162 | for dataset_name in exam_datasets: 163 | image_encoder = adamerging_mtl_model.get_image_encoder() 164 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 165 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 166 | Total_ACC += metrics['top1'] 167 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 168 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 169 | 170 | for epoch in range(epochs): 171 | losses = 0. 172 | for dataset_name in exam_datasets: 173 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16) 174 | dataloader = get_dataloader_shuffle(dataset) 175 | 176 | for i, data in enumerate(tqdm.tqdm(dataloader)): 177 | data = maybe_dictionarize(data) 178 | x = data['images'].to(args.device) 179 | y = data['labels'].to(args.device) 180 | 181 | outputs = adamerging_mtl_model(x, dataset_name) 182 | loss = softmax_entropy(outputs).mean(0) 183 | losses += loss 184 | 185 | if i > 0: 186 | break 187 | 188 | optimizer.zero_grad() 189 | losses.backward() 190 | optimizer.step() 191 | 192 | if ((epoch+1) % 500) == 0: 193 | log.info(str(list(adamerging_mtl_model.lambdas().data))) 194 | 195 | Total_ACC = 0. 
196 | for dataset_name in exam_datasets: 197 | image_encoder = adamerging_mtl_model.get_image_encoder() 198 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 199 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 200 | Total_ACC += metrics['top1'] 201 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 202 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 203 | -------------------------------------------------------------------------------- /src/main_task_wise_adamergingpp.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 3 | 4 | import time 5 | import sys 6 | import tqdm 7 | sys.path.append('/home/taskarithmetic/') 8 | 9 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head 10 | from args import parse_arguments 11 | 12 | def create_log_dir(path, filename='log.txt'): 13 | import logging 14 | if not os.path.exists(path): 15 | os.makedirs(path) 16 | logger = logging.getLogger(path) 17 | logger.setLevel(logging.DEBUG) 18 | fh = logging.FileHandler(path+'/'+filename) 19 | fh.setLevel(logging.DEBUG) 20 | ch = logging.StreamHandler() 21 | ch.setLevel(logging.DEBUG) 22 | logger.addHandler(fh) 23 | logger.addHandler(ch) 24 | return logger 25 | 26 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD 27 | model = 'ViT-B-32' 28 | args = parse_arguments() 29 | args.data_location = '/home/taskarithmetic/data' 30 | args.model = model 31 | args.save = '/home/taskarithmetic/checkpoints/' + model 32 | args.logs_path = '/home/taskarithmetic/logs/' + model 33 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt' 34 | 35 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 36 | log = create_log_dir(args.logs_path, 'log_{}_Task_wise_AdaMergingPP.txt'.format(str_time_)) 37 | args.log = log 38 | 39 | from ties_merging_utils import * 40 | 41 | ft_checks = [torch.load('/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt').state_dict() for dataset_name in exam_datasets] 42 | ptm_check = torch.load(pretrained_checkpoint).state_dict() 43 | 44 | check_parameterNamesMatch(ft_checks + [ptm_check]) 45 | 46 | remove_keys = [] 47 | 48 | print(f"Flattening out Checkpoints") 49 | flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks]) 50 | flat_ptm = state_dict_to_vector(ptm_check, remove_keys) 51 | 52 | tv_flat_checks = flat_ft - flat_ptm 53 | 54 | assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check) 55 | assert all([check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_checks[i])for i in range(len(ft_checks))]) 56 | 57 | K = 20 58 | merge_func = "dis-sum" 59 | 60 | selected_entries, merged_tv = ties_merging_split(tv_flat_checks, reset_thresh=K, merge_func=merge_func,) 61 | 62 | ties_task_vectors = [] 63 | for vector_ in selected_entries: 64 | t_state_dict = vector_to_state_dict(vector_, ptm_check, remove_keys=remove_keys) 65 | ref_model = torch.load(pretrained_checkpoint) 66 | ref_model.load_state_dict(t_state_dict, strict=False) 67 | ties_task_vectors.append(ref_model.state_dict()) 68 | 69 | def del_attr(obj, 
names): 70 | if len(names) == 1: 71 | delattr(obj, names[0]) 72 | else: 73 | del_attr(getattr(obj, names[0]), names[1:]) 74 | 75 | def set_attr(obj, names, val): 76 | if len(names) == 1: 77 | setattr(obj, names[0], val) 78 | else: 79 | set_attr(getattr(obj, names[0]), names[1:], val) 80 | 81 | def make_functional(mod): 82 | orig_params = tuple(mod.parameters()) 83 | # Remove all the parameters in the model 84 | names = [] 85 | for name, p in list(mod.named_parameters()): 86 | del_attr(mod, name.split(".")) 87 | names.append(name) 88 | return orig_params, names 89 | 90 | def load_weights(mod, names, params): 91 | for name, p in zip(names, params): 92 | set_attr(mod, name.split("."), p) 93 | 94 | 95 | class ModelWrapper(torch.nn.Module): 96 | def __init__(self, model, initial_weights=None): 97 | super(ModelWrapper, self).__init__() 98 | self.model = model 99 | 100 | if hasattr(self.model, 'transformer'): 101 | delattr(self.model, 'transformer') 102 | 103 | def forward(self, images): 104 | features = self.model(images) 105 | return features 106 | 107 | from heads import get_classification_head 108 | class AdaMerging(torch.nn.Module): 109 | def __init__(self, paramslist, model, names, exam_datasets): 110 | super(AdaMerging, self).__init__() 111 | self.paramslist = paramslist 112 | self.model = model 113 | self.names = names 114 | self.pretrain_lambdas = torch.ones(1, 1) 115 | prior = 0.3 116 | rlambdas = torch.ones(1, len(paramslist)-1) * prior # (1 * tasks) 117 | self.lambdas_raw = torch.nn.Parameter(rlambdas) 118 | 119 | self.classifier = [] 120 | for dataset_name in exam_datasets: 121 | classification_head = get_classification_head(args, dataset_name) 122 | layer_name = 'classifier_{}'.format(dataset_name) 123 | self.add_module(layer_name, classification_head.to(args.device)) 124 | self.classifier.append(layer_name) 125 | 126 | def lambdas(self): 127 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0) 128 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1) 129 | return lambdass 130 | 131 | def collect_trainable_params(self): 132 | return [self.lambdas_raw] 133 | 134 | def get_classification_head(self, dataset_name): 135 | layer_name = 'classifier_{}'.format(dataset_name) 136 | classification_head = getattr(self, layer_name) 137 | return classification_head 138 | 139 | def get_image_encoder(self): 140 | alph = self.lambdas() 141 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 142 | params = tuple(p.cuda(0) for p in params) 143 | load_weights(self.model, self.names, params) 144 | return self.model 145 | 146 | def forward(self, inp, dataset_name): 147 | alph = self.lambdas() 148 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist))) 149 | 150 | params = tuple(p.cuda(0) for p in params) 151 | load_weights(self.model, self.names, params) 152 | feature = self.model(inp) 153 | 154 | layer_name = 'classifier_{}'.format(dataset_name) 155 | classification_head = getattr(self, layer_name) 156 | out = classification_head(feature) 157 | 158 | return out 159 | 160 | def softmax_entropy(x): 161 | return -(x.softmax(1) * x.log_softmax(1)).sum(1) 162 | 163 | pretrained_model = torch.load(pretrained_checkpoint) 164 | pretrained_model_dic = pretrained_model.state_dict() 165 | 166 | model = ModelWrapper(pretrained_model, exam_datasets) 167 | model = model.to(args.device) 168 | _, names = make_functional(model) 169 | 170 | 
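# make_functional stripped the encoder's parameters so load_weights can later
# re-attach plain tensors, keeping the merged weights differentiable w.r.t. the
# merging coefficients. A minimal sketch of the per-parameter weighted sum that
# AdaMerging.forward computes below, assuming toy tensors (the helper name is
# illustrative, not part of the repo):
def _merge_one_param(pretrained_p, task_ps, task_lambdas):
    # theta = 1.0 * theta_pretrained + sum_k lambda_k * tau_k, with the
    # pretrained coefficient fixed at 1 and each lambda_k clamped to [0, 1].
    merged = pretrained_p.clone()
    for lam, tau in zip(task_lambdas, task_ps):
        merged = merged + lam * tau
    return merged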
paramslist = [] 171 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain 172 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.items()) for i, tv in enumerate(ties_task_vectors)] # task vectors 173 | torch.cuda.empty_cache() 174 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets) 175 | 176 | print('init lambda:') 177 | print(adamerging_mtl_model.lambdas()) 178 | print('collect_trainable_params:') 179 | print(list(adamerging_mtl_model.collect_trainable_params())) 180 | 181 | epochs = 500 182 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.) 183 | 184 | from datasets.registry import get_dataset 185 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle 186 | 187 | Total_ACC = 0. 188 | for dataset_name in exam_datasets: 189 | image_encoder = adamerging_mtl_model.get_image_encoder() 190 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 191 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 192 | Total_ACC += metrics['top1'] 193 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 194 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 195 | 196 | for epoch in range(epochs): 197 | losses = 0. 198 | for dataset_name in exam_datasets: 199 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16) 200 | dataloader = get_dataloader_shuffle(dataset) 201 | 202 | for i, data in enumerate(tqdm.tqdm(dataloader)): 203 | data = maybe_dictionarize(data) 204 | x = data['images'].to(args.device) 205 | y = data['labels'].to(args.device) 206 | 207 | outputs = adamerging_mtl_model(x, dataset_name) 208 | loss = softmax_entropy(outputs).mean(0) 209 | losses += loss 210 | 211 | if i > 0: 212 | break 213 | 214 | optimizer.zero_grad() 215 | losses.backward() 216 | optimizer.step() 217 | 218 | print(list(adamerging_mtl_model.lambdas().data)) 219 | 220 | if ((epoch+1) % 500) == 0: 221 | log.info(str(list(adamerging_mtl_model.lambdas().data))) 222 | 223 | Total_ACC = 0. 
224 | for dataset_name in exam_datasets: 225 | image_encoder = adamerging_mtl_model.get_image_encoder() 226 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name) 227 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args) 228 | Total_ACC += metrics['top1'] 229 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1'])) 230 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n') 231 | -------------------------------------------------------------------------------- /src/main_ties_merging.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 3 | 4 | import time 5 | import sys 6 | sys.path.append('/home/taskarithmetic/') 7 | 8 | from eval import eval_single_dataset 9 | from args import parse_arguments 10 | 11 | def create_log_dir(path, filename='log.txt'): 12 | import logging 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | logger = logging.getLogger(path) 16 | logger.setLevel(logging.DEBUG) 17 | fh = logging.FileHandler(path+'/'+filename) 18 | fh.setLevel(logging.DEBUG) 19 | ch = logging.StreamHandler() 20 | ch.setLevel(logging.DEBUG) 21 | logger.addHandler(fh) 22 | logger.addHandler(ch) 23 | return logger 24 | 25 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD 26 | model = 'ViT-B-32' 27 | args = parse_arguments() 28 | args.data_location = '/home/taskarithmetic/data' 29 | args.model = model 30 | args.save = '/home/taskarithmetic/checkpoints/' + model 31 | args.logs_path = '/home/taskarithmetic/logs/' + model 32 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt' 33 | 34 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) 35 | log = create_log_dir(args.logs_path, 'log_{}_ties_merging.txt'.format(str_time_)) 36 | 37 | from ties_merging_utils import * 38 | ft_checks = [torch.load('/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt').state_dict() for dataset_name in exam_datasets] 39 | ptm_check = torch.load(pretrained_checkpoint).state_dict() 40 | check_parameterNamesMatch(ft_checks + [ptm_check]) 41 | 42 | remove_keys = [] 43 | print(f"Flattening out Checkpoints") 44 | flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks]) 45 | flat_ptm = state_dict_to_vector(ptm_check, remove_keys) 46 | 47 | tv_flat_checks = flat_ft - flat_ptm 48 | assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check) 49 | assert all([check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_checks[i])for i in range(len(ft_checks))]) 50 | 51 | 52 | K = 20 53 | merge_func = "dis-sum" 54 | scaling_coef_ = 0.3 55 | 56 | merged_tv = ties_merging(tv_flat_checks, reset_thresh=K, merge_func=merge_func,) 57 | merged_check = flat_ptm + scaling_coef_ * merged_tv 58 | merged_state_dict = vector_to_state_dict(merged_check, ptm_check, remove_keys=remove_keys) 59 | 60 | image_encoder = torch.load(pretrained_checkpoint) 61 | image_encoder.load_state_dict(merged_state_dict, strict=False) 62 | 63 | Total_ACC = 0. 
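# The ties_merging call above applies trim / elect-sign / disjoint-merge to the
# stacked flat task vectors, with reset_thresh=K keeping the top K% (here 20%)
# of entries. A minimal sketch of those three steps, assuming a
# (num_tasks, num_params) float tensor and the disjoint-mean variant (the
# script itself passes merge_func="dis-sum"); the repo's real implementation
# lives in ties_merging_utils:
def _ties_sketch(task_vecs, top_frac=0.2):
    import torch
    n = task_vecs.shape[1]
    n_keep = max(1, int(top_frac * n))
    # 1) Trim: zero everything below each row's top-`n_keep` magnitudes.
    thresh = task_vecs.abs().kthvalue(n - n_keep + 1, dim=1, keepdim=True).values
    trimmed = task_vecs * (task_vecs.abs() >= thresh)
    # 2) Elect: per-coordinate sign carrying the larger total mass.
    elected = torch.sign(trimmed.sum(dim=0))
    # 3) Disjoint merge: aggregate only entries agreeing with the elected sign.
    agree = (torch.sign(trimmed) == elected) & (trimmed != 0)
    return (trimmed * agree).sum(dim=0) / agree.sum(dim=0).clamp(min=1)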
/src/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import open_clip
4 |
5 | import utils
6 |
7 |
8 | class ImageEncoder(torch.nn.Module):
9 |     def __init__(self, args, keep_lang=False):
10 |         super().__init__()
11 |
12 |         print(f'Loading {args.model} pre-trained weights.')
13 |         if '__pretrained__' in args.model:
14 |             name, pretrained = args.model.split('__pretrained__')
15 |         else:
16 |             name = args.model
17 |             pretrained = 'openai'
18 |         self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms(
19 |             name, pretrained=pretrained, cache_dir=args.openclip_cachedir)
20 |
21 |         self.cache_dir = args.cache_dir
22 |
23 |         if not keep_lang and hasattr(self.model, 'transformer'):
24 |             delattr(self.model, 'transformer')  # drop the text tower; only the image encoder is kept
25 |
26 |     def forward(self, images):
27 |         assert self.model is not None
28 |         return self.model.encode_image(images)
29 |
30 |     def __call__(self, inputs):
31 |         return self.forward(inputs)
32 |
33 |     def save(self, filename):
34 |         print(f'Saving image encoder to {filename}')
35 |         utils.torch_save(self, filename)
36 |
37 |     @classmethod
38 |     def load(cls, model_name, filename):
39 |         print(f'Loading image encoder from {filename}')
40 |         state_dict = torch.load(filename)
41 |         return cls.load_from_state_dict(model_name, state_dict)
42 |
43 |     @classmethod
44 |     def load_from_state_dict(cls, model_name, state_dict):
45 |         # Rebuild the open_clip architecture for `model_name`; the downloaded
46 |         # weights are immediately replaced by the provided state dict.
47 |         model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained='openai')
48 |         model.load_state_dict(state_dict)
49 |         return model
50 |
51 |
52 | class ClassificationHead(torch.nn.Linear):
53 |     def __init__(self, normalize, weights, biases=None):
54 |         output_size, input_size = weights.shape
55 |         super().__init__(input_size, output_size)
56 |         self.normalize = normalize
57 |         if weights is not None:
58 |             self.weight = torch.nn.Parameter(weights.clone())
59 |         if biases is not None:
60 |             self.bias = torch.nn.Parameter(biases.clone())
61 |         else:
62 |             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
63 |
64 |     def forward(self, inputs):
65 |         if self.normalize:
66 |             inputs = inputs / inputs.norm(dim=-1, keepdim=True)
67 |         return super().forward(inputs)
68 |
69 |     def __call__(self, inputs):
70 |         return self.forward(inputs)
71 |
72 |     def save(self, filename):
73 |         print(f'Saving classification head to {filename}')
74 |         utils.torch_save(self, filename)
75 |
76 |     @classmethod
77 |     def load(cls, filename):
78 |         print(f'Loading classification head from {filename}')
79 |         return utils.torch_load(filename)
80 |
81 |
82 | class ImageClassifier(torch.nn.Module):
83 |     def __init__(self, image_encoder, classification_head):
84 |         super().__init__()
85 |         self.image_encoder = image_encoder
86 |         self.classification_head = classification_head
87 |         if self.image_encoder is not None:
88 |             if hasattr(self.image_encoder, 'train_preprocess'):
89 |                 self.train_preprocess = self.image_encoder.train_preprocess
90 |                 self.val_preprocess = self.image_encoder.val_preprocess
91 |             elif hasattr(self.image_encoder.model, 'train_preprocess'):
92 |                 self.train_preprocess = self.image_encoder.model.train_preprocess
93 |                 self.val_preprocess = self.image_encoder.model.val_preprocess
94 |
95 |     def freeze_head(self):
96 |         self.classification_head.weight.requires_grad_(False)
97 |         self.classification_head.bias.requires_grad_(False)
98 |
99 |     def forward(self, inputs):
100 |         features = self.image_encoder(inputs)
101 |         outputs = self.classification_head(features)
102 |         return outputs
103 |
104 |     def __call__(self, inputs):
105 |         return self.forward(inputs)
106 |
107 |     def save(self, filename):
108 |         print(f'Saving image classifier to {filename}')
109 |         utils.torch_save(self, filename)
110 |
111 |     @classmethod
112 |     def load(cls, filename):
113 |         print(f'Loading image classifier from {filename}')
114 |         return utils.torch_load(filename)
115 |
116 | class ImageClassifier_debug(torch.nn.Module):
117 |     def __init__(self, image_encoder, image_encoder2, classification_head):
118 |         super().__init__()
119 |         self.image_encoder = image_encoder
120 |         self.image_encoder2 = image_encoder2
121 |         self.classification_head = classification_head
122 |         if self.image_encoder is not None:
123 |             self.train_preprocess = self.image_encoder.train_preprocess
124 |             self.val_preprocess = self.image_encoder.val_preprocess
125 |
126 |     def freeze_head(self):
127 |         self.classification_head.weight.requires_grad_(False)
128 |         self.classification_head.bias.requires_grad_(False)
129 |
130 |     def forward(self, inputs):
131 |         features = self.image_encoder(inputs)
132 |         features2 = self.image_encoder2(inputs)
133 |         outputs = self.classification_head(features + features2)
134 |         return outputs
135 |
136 |     def __call__(self, inputs):
137 |         return self.forward(inputs)
138 |
139 |     def save(self, filename):
140 |         print(f'Saving image classifier to {filename}')
141 |         utils.torch_save(self, filename)
142 |
143 |     @classmethod
144 |     def load(cls, filename):
145 |         print(f'Loading image classifier from {filename}')
146 |         return utils.torch_load(filename)
147 |
148 | class MultiHeadImageClassifier(torch.nn.Module):
149 |     def __init__(self, image_encoder, classification_heads):
150 |         super().__init__()
151 |         self.image_encoder = image_encoder
152 |         self.classification_heads = torch.nn.ModuleList(classification_heads)
153 |         if self.image_encoder is not None:
154 |             self.train_preprocess = self.image_encoder.train_preprocess
155 |             self.val_preprocess = self.image_encoder.val_preprocess
156 |
157 |     def freeze_head(self):
158 |         for idx in range(len(self.classification_heads)):
159 |             self.classification_heads[idx].weight.requires_grad_(False)
160 |             self.classification_heads[idx].bias.requires_grad_(False)
161 |
162 |     def forward(self, inputs, head_idx):
163 |         features = self.image_encoder(inputs)
164 |         outputs = self.classification_heads[head_idx](features)
165 |         return outputs
166 |
167 |     def __call__(self, inputs, head_idx):
168 |         return self.forward(inputs, head_idx)
169 |
170 |     def save(self, filename):
171 |         print(f'Saving image classifier to {filename}')
172 |         utils.torch_save(self, filename)
173 |
174 |     @classmethod
175 |     def load(cls, filename):
176 |         print(f'Loading image classifier from {filename}')
177 |         return utils.torch_load(filename)
--------------------------------------------------------------------------------
/src/task_vectors.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TaskVector():
5 |     def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None):
6 |         """Initializes the task vector from a pretrained and a finetuned checkpoint.
7 |
8 |         This can either be done by passing in two checkpoint paths (one corresponding to the
9 |         pretrained model, and another to the finetuned model), or by directly passing in
10 |         the task vector state dict.
11 |         """
12 |         if vector is not None:
13 |             self.vector = vector
14 |         else:
15 |             assert pretrained_checkpoint is not None and finetuned_checkpoint is not None
16 |             with torch.no_grad():
17 |                 print('TaskVector:' + finetuned_checkpoint)
18 |                 pretrained_state_dict = torch.load(pretrained_checkpoint).state_dict()
19 |                 finetuned_state_dict = torch.load(finetuned_checkpoint).state_dict()
20 |                 self.vector = {}
21 |                 for key in pretrained_state_dict:
22 |                     if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]:
23 |                         continue  # skip integer tensors (e.g. counters), which are not merged
24 |                     self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key]
25 |
26 |     def __add__(self, other):
27 |         """Add two task vectors together."""
28 |         with torch.no_grad():
29 |             new_vector = {}
30 |             for key in self.vector:
31 |                 if key not in other.vector:
32 |                     print(f'Warning, key {key} is not present in both task vectors.')
33 |                     continue
34 |                 new_vector[key] = self.vector[key] + other.vector[key]
35 |         return TaskVector(vector=new_vector)
36 |
37 |     def __radd__(self, other):
38 |         if other is None or isinstance(other, int):
39 |             return self  # lets built-in sum() work, since sum() starts from the int 0
40 |         return self.__add__(other)
41 |
42 |     def __neg__(self):
43 |         """Negate a task vector."""
44 |         with torch.no_grad():
45 |             new_vector = {}
46 |             for key in self.vector:
47 |                 new_vector[key] = - self.vector[key]
48 |         return TaskVector(vector=new_vector)
49 |
50 |     def weightmerging(self, taskvectors, coefficients):
51 |         with torch.no_grad():
52 |             new_vector = {}
53 |             for key in taskvectors[0].vector:
54 |                 new_vector[key] = sum(coefficients[k] * taskvectors[k].vector[key] for k in range(len(taskvectors)))  # each element of taskvectors is a TaskVector, so index into its .vector dict
55 |         return TaskVector(vector=new_vector)
56 |
57 |     def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
58 |         """Apply a task vector to a pretrained model."""
59 |         with torch.no_grad():
60 |             pretrained_model = torch.load(pretrained_checkpoint)
61 |             new_state_dict = {}
62 |             pretrained_state_dict = pretrained_model.state_dict()
63 |             for key in pretrained_state_dict:
64 |                 if key not in self.vector:
65 |                     print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector')
66 |                     continue
67 |                 new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key]
68 |         pretrained_model.load_state_dict(new_state_dict, strict=False)
69 |         return pretrained_model
70 |
71 |
--------------------------------------------------------------------------------
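A hypothetical driver snippet (not a file in this repository) showing how the TaskVector class above is typically used for plain task arithmetic; the checkpoint paths mirror those hard-coded in src/main_ties_merging.py:

from task_vectors import TaskVector

exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD']
ckpt_root = '/home/taskarithmetic/checkpoints/ViT-B-32'
pretrained_checkpoint = ckpt_root + '/zeroshot.pt'

# One task vector per task; built-in sum() works because __radd__ accepts the int 0.
task_vectors = [TaskVector(pretrained_checkpoint, ckpt_root + '/' + d + '/finetuned.pt') for d in exam_datasets]
merged = sum(task_vectors)

# Add the summed vector back onto the pretrained weights with a global coefficient.
image_encoder = merged.apply_to(pretrained_checkpoint, scaling_coef=0.3)

--------------------------------------------------------------------------------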
/src/ties_merging_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os, copy
3 | import torch
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import re
7 | from collections import OrderedDict
8 | import torch.nn.functional as F
9 | # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10 |
11 | ## Model conversion utils
12 | def state_dict_to_vector(state_dict, remove_keys=[]):
13 |     shared_state_dict = copy.deepcopy(state_dict)
14 |     for key in remove_keys:
15 |         if key in shared_state_dict:
16 |             del shared_state_dict[key]
17 |     sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
18 |     return torch.nn.utils.parameters_to_vector(
19 |         [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
20 |     )
21 |
22 |
23 | def vector_to_state_dict(vector, state_dict, remove_keys=[]):
24 |     # create a reference dict to define the order of the vector
25 |     reference_dict = copy.deepcopy(state_dict)
26 |     for key in remove_keys:
27 |         if key in reference_dict:
28 |             del reference_dict[key]
29 |     sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
30 |
31 |     # create a shared state dict using the reference dict
32 |     torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
33 |
34 |     # add back the encoder and decoder embedding weights.
35 |     if "transformer.shared.weight" in sorted_reference_dict:
36 |         for key in remove_keys:
37 |             sorted_reference_dict[key] = sorted_reference_dict[
38 |                 "transformer.shared.weight"
39 |             ]
40 |     return sorted_reference_dict
41 |
42 |
43 | def add_ptm_to_tv(tv_dict, ptm_dict):
44 |     assert set(tv_dict.keys()) == set(
45 |         ptm_dict.keys()
46 |     ), "Differing parameter names in models."
47 |     final_dict = copy.deepcopy(tv_dict)
48 |     for k, v in ptm_dict.items():
49 |         final_dict[k] = tv_dict[k] + v
50 |     return final_dict
51 |
52 |
53 | def check_parameterNamesMatch(checkpoints):
54 |     parameter_names = set(checkpoints[0].keys())
55 |
56 |     if len(checkpoints) >= 2:
57 |         # raise ValueError("Number of models is less than 2.")
58 |         for checkpoint in checkpoints[1:]:
59 |             current_parameterNames = set(checkpoint.keys())
60 |             if current_parameterNames != parameter_names:
61 |                 raise ValueError(
62 |                     "Differing parameter names in models. "
63 |                     f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
64 |                 )
65 |
66 | def check_state_dicts_equal(state_dict1, state_dict2):
67 |     if set(state_dict1.keys()) != set(state_dict2.keys()):
68 |         return False
69 |
70 |     for key in state_dict1.keys():
71 |         if not torch.equal(state_dict1[key], state_dict2[key]):
72 |             return False
73 |
74 |     return True
75 |
76 |
77 |
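# Hypothetical sanity check (not in the original file) illustrating that the two
# conversion helpers above are inverses up to key ordering -- the property the
# asserts in src/main_ties_merging.py rely on:
if __name__ == "__main__":
    _sd = OrderedDict([('b.weight', torch.randn(2, 3)), ('a.bias', torch.randn(2))])
    _vec = state_dict_to_vector(_sd)        # keys sorted, values flattened and concatenated
    _sd2 = vector_to_state_dict(_vec, _sd)  # sliced back into the reference shapes
    assert check_state_dicts_equal(_sd, _sd2)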
78 | ## TIES MERGING UTILS
79 |
80 | def topk_values_mask(M, K=0.7, return_mask=False):
81 |     if K > 1:
82 |         K /= 100  # a K given in percent (e.g. 20) is converted to a fraction
83 |
84 |     original_shape = M.shape
85 |     if M.dim() == 1:
86 |         M = M.unsqueeze(0)
87 |
88 |     n, d = M.shape
89 |     k = int(d * K)
90 |     k = d - k  # Keep top k elements instead of bottom k elements
91 |
92 |     # Find the k-th smallest element by magnitude for each row
93 |     kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
94 |     # Create a mask tensor with True for the top k elements in each row
95 |     mask = M.abs() >= kth_values
96 |     final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
97 |
98 |     if return_mask:
99 |         return M * final_mask, final_mask.float().mean(dim=1), final_mask
100 |     return M * final_mask, final_mask.float().mean(dim=1)
101 |
102 |
103 | def resolve_zero_signs(sign_to_mult, method="majority"):
104 |     majority_sign = torch.sign(sign_to_mult.sum())
105 |
106 |     if method == "majority":
107 |         sign_to_mult[sign_to_mult == 0] = majority_sign
108 |     elif method == "minority":
109 |         sign_to_mult[sign_to_mult == 0] = -1 * majority_sign
110 |     return sign_to_mult
111 |
112 |
113 | def resolve_sign(Tensor):
114 |     sign_to_mult = torch.sign(Tensor.sum(dim=0))  # elect, per coordinate, the sign with the larger total mass across tasks
115 |     sign_to_mult = resolve_zero_signs(sign_to_mult, "majority")
116 |     return sign_to_mult
117 |
118 |
119 | def disjoint_merge(Tensor, merge_func, sign_to_mult):
120 |     merge_func = merge_func.split("-")[-1]  # e.g. "dis-sum" -> "sum"
121 |
122 |     # If sign is provided then we select the corresponding entries and aggregate.
123 |     if sign_to_mult is not None:
124 |         rows_to_keep = torch.where(
125 |             sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
126 |         )
127 |         selected_entries = Tensor * rows_to_keep
128 |     # Else we select all non-zero entries and aggregate.
129 |     else:
130 |         rows_to_keep = Tensor != 0
131 |         selected_entries = Tensor * rows_to_keep
132 |
133 |     if merge_func == "mean":
134 |         non_zero_counts = (selected_entries != 0).sum(dim=0).float()
135 |         disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(non_zero_counts, min=1)
136 |     elif merge_func == "sum":
137 |         disjoint_aggs = torch.sum(selected_entries, dim=0)
138 |     elif merge_func == "max":
139 |         disjoint_aggs = selected_entries.abs().max(dim=0)[0]
140 |         disjoint_aggs *= sign_to_mult
141 |     else:
142 |         raise ValueError(f"Merge method {merge_func} is not defined.")
143 |
144 |     return disjoint_aggs
145 |
146 |
147 | def ties_merging(
148 |     flat_task_checks,
149 |     reset_thresh=None,
150 |     merge_func="",
151 | ):
152 |     all_checks = flat_task_checks.clone()
153 |     updated_checks, *_ = topk_values_mask(
154 |         all_checks, K=reset_thresh, return_mask=False
155 |     )
156 |     print("RESOLVING SIGN")
157 |     final_signs = resolve_sign(updated_checks)
158 |     assert final_signs is not None
159 |
160 |     print(f"Disjoint AGGREGATION: {merge_func}")
161 |     merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)
162 |
163 |     return merged_tv
164 |
165 | def disjoint_merge_split(Tensor, merge_func, sign_to_mult):
166 |     merge_func = merge_func.split("-")[-1]
167 |
168 |     # If sign is provided then we select the corresponding entries and aggregate.
169 |     if sign_to_mult is not None:
170 |         rows_to_keep = torch.where(
171 |             sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
172 |         )
173 |         selected_entries = Tensor * rows_to_keep
174 |     # Else we select all non-zero entries and aggregate.
175 |     else:
176 |         rows_to_keep = Tensor != 0
177 |         selected_entries = Tensor * rows_to_keep
178 |
179 |     if merge_func == "sum":
180 |         disjoint_aggs = torch.sum(selected_entries, dim=0)
181 |     else:
182 |         raise ValueError(f"Merge method {merge_func} is not defined.")
183 |
184 |     return selected_entries, disjoint_aggs
185 |
186 |
187 | def ties_merging_split(
188 |     flat_task_checks,
189 |     reset_thresh=None,
190 |     merge_func="",
191 | ):
192 |     all_checks = flat_task_checks.clone()
193 |     updated_checks, *_ = topk_values_mask(
194 |         all_checks, K=reset_thresh, return_mask=False
195 |     )
196 |     print("RESOLVING SIGN")
197 |     final_signs = resolve_sign(updated_checks)
198 |     assert final_signs is not None
199 |
200 |     print(f"Disjoint AGGREGATION: {merge_func}")
201 |     selected_entries, merged_tv = disjoint_merge_split(updated_checks, merge_func, final_signs)
202 |
203 |     return selected_entries, merged_tv
204 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import pickle
5 | from tqdm import tqdm
6 | import math
7 |
8 | import numpy as np
9 |
10 |
11 | def assign_learning_rate(param_group, new_lr):
12 |     param_group["lr"] = new_lr
13 |
14 |
15 | def _warmup_lr(base_lr, warmup_length, step):
16 |     return base_lr * (step + 1) / warmup_length
17 |
18 |
19 | def cosine_lr(optimizer, base_lrs, warmup_length, steps):
20 |     if not isinstance(base_lrs, list):
21 |         base_lrs = [base_lrs for _ in optimizer.param_groups]
22 |     assert len(base_lrs) == len(optimizer.param_groups)
23 |     def _lr_adjuster(step):
24 |         for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
25 |             if step < warmup_length:
26 |                 lr = _warmup_lr(base_lr, warmup_length, step)
27 |             else:
28 |                 e = step - warmup_length
29 |                 es = steps - warmup_length
30 |                 lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
31 |             assign_learning_rate(param_group, lr)
32 |     return _lr_adjuster
33 |
34 |
35 | def accuracy(output, target, topk=(1,)):
36 |     pred = output.topk(max(topk), 1, True, True)[1].t()
37 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
38 |     return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
39 |
40 |
41 | def torch_load_old(save_path, device=None):
42 |     with open(save_path, 'rb') as f:
43 |         classifier = pickle.load(f)
44 |     if device is not None:
45 |         classifier = classifier.to(device)
46 |     return classifier
47 |
48 |
49 | def torch_save(model, save_path):
50 |     if os.path.dirname(save_path) != '':
51 |         os.makedirs(os.path.dirname(save_path), exist_ok=True)
52 |     torch.save(model.cpu(), save_path)
53 |
54 |
55 | def torch_load(save_path, device=None):
56 |     model = torch.load(save_path)
57 |     if device is not None:
58 |         model = model.to(device)
59 |     return model
60 |
61 |
62 |
63 | def get_logits(inputs, classifier):
64 |     assert callable(classifier)
65 |     if hasattr(classifier, 'to'):
66 |         classifier = classifier.to(inputs.device)
67 |     return classifier(inputs)
68 |
69 |
70 | def get_probs(inputs, classifier):
71 |     if hasattr(classifier, 'predict_proba'):
72 |         probs = classifier.predict_proba(inputs.detach().cpu().numpy())
73 |         return torch.from_numpy(probs)
74 |     logits = get_logits(inputs, classifier)
75 |     return logits.softmax(dim=1)
76 |
77 |
78 | class LabelSmoothing(torch.nn.Module):
79 |     def __init__(self, smoothing=0.0):
80 |         super(LabelSmoothing, self).__init__()
81 |         self.confidence = 1.0 - smoothing
82 |         self.smoothing = smoothing
83 |
84 |     def forward(self, x, target):
85 |         logprobs = torch.nn.functional.log_softmax(x, dim=-1)
86 |
87 |         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
88 |         nll_loss = nll_loss.squeeze(1)
89 |         smooth_loss = -logprobs.mean(dim=-1)
90 |         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
91 |         return loss.mean()
92 |
--------------------------------------------------------------------------------
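
A hypothetical usage sketch (not part of the repository) for the cosine_lr scheduler in src/utils.py above: it returns a callable that must be invoked once per training step and mutates the optimizer's learning rates in place (linear warmup followed by cosine decay).

import torch
from utils import cosine_lr

params = [torch.zeros(10, requires_grad=True)]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = cosine_lr(optimizer, base_lrs=0.1, warmup_length=50, steps=1000)

for step in range(1000):
    scheduler(step)                # sets each param_group['lr'] for this step
    optimizer.zero_grad()
    loss = (params[0] ** 2).sum()  # dummy objective, stands in for a real training loss
    loss.backward()
    optimizer.step()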