├── .DS_Store
├── .gitignore
├── AdaMerging.png
├── LICENSE
├── README.md
├── checkpoints
│   └── README.md
├── data
│   └── README.md
├── logs
│   └── ViT-B-32
│       ├── log_20230920_214801_Layer_wise_AdaMergingPP.txt
│       ├── log_20230921_193533_Layer_wise_AdaMerging.txt
│       ├── log_20230921_202228_Task_wise_AdaMerging.txt
│       └── log_20230921_205130_Task_wise_AdaMergingPP.txt
└── src
    ├── args.py
    ├── datasets
    │   ├── cars.py
    │   ├── common.py
    │   ├── dtd.py
    │   ├── eurosat.py
    │   ├── gtsrb.py
    │   ├── mnist.py
    │   ├── registry.py
    │   ├── resisc45.py
    │   ├── sun397.py
    │   ├── svhn.py
    │   └── templates.py
    ├── eval.py
    ├── heads.py
    ├── main_layer_wise_adamerging.py
    ├── main_layer_wise_adamergingpp.py
    ├── main_task_arithmetic.py
    ├── main_task_wise_adamerging.py
    ├── main_task_wise_adamergingpp.py
    ├── main_ties_merging.py
    ├── merging_cofficient.py
    ├── modeling.py
    ├── task_vectors.py
    ├── ties_merging_utils.py
    └── utils.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnnengYang/AdaMerging/cea1165ad900113b8a52ae5ff190626576635f9b/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .DS_Store
3 |
--------------------------------------------------------------------------------
/AdaMerging.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EnnengYang/AdaMerging/cea1165ad900113b8a52ae5ff190626576635f9b/AdaMerging.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Enneng Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaMerging
2 |
3 | A repository for **"[AdaMerging: Adaptive Model Merging for Multi-Task Learning](https://arxiv.org/abs/2310.02575)" (ICLR, 2024)**.
4 |
5 |
6 | ## Abstract
7 | > Multi-task learning (MTL) aims to empower a model to tackle multiple tasks simultaneously. A recent development known as task arithmetic has revealed that several models, each fine-tuned for distinct tasks, can be directly merged into a single model to execute MTL without necessitating a retraining process using the initial training data. Nevertheless, this direct addition of models often leads to a significant deterioration in the overall performance of the merged model. This decline occurs due to potential conflicts and intricate correlations among the multiple tasks. Consequently, the challenge emerges of how to merge pre-trained models more effectively without using their original training data. This paper introduces an innovative technique called Adaptive Model Merging (AdaMerging). This approach aims to autonomously learn the coefficients for model merging, either in a task-wise or layer-wise manner, without relying on the original training data. Specifically, our AdaMerging method operates as an automatic, unsupervised task arithmetic scheme. It leverages entropy minimization on unlabeled test samples from the multi-task setup as a surrogate objective function to iteratively refine the merging coefficients of the multiple models. Our experimental findings across eight tasks demonstrate the efficacy of the AdaMerging scheme we put forth. Compared to the current state-of-the-art (SOTA) task arithmetic merging scheme, AdaMerging showcases a remarkable 11% improvement in performance. Notably, AdaMerging also exhibits superior generalization capabilities when applied to unseen downstream tasks. Furthermore, it displays a significantly enhanced robustness to data distribution shifts that may occur during the testing phase.
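
To make this concrete, below is a minimal, self-contained sketch of the task-wise scheme (not the repository's actual implementation; see `src/main_task_wise_adamerging.py` for that). It assumes `model` is the backbone network, `theta_pre` the pretrained weights as a name-to-tensor dict, `task_vectors` a list of dicts with `tau_k = theta_k - theta_pre`, and `unlabeled_batches` an iterator over unlabeled test inputs:

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call

def task_wise_adamerging(model, theta_pre, task_vectors, unlabeled_batches,
                         init=0.3, lr=1e-3):
    """Learn one merging coefficient per task by entropy minimization."""
    # initialized to 0.3, the common task-arithmetic default
    lambdas = torch.nn.Parameter(init * torch.ones(len(task_vectors)))
    opt = torch.optim.Adam([lambdas], lr=lr)
    for x in unlabeled_batches:  # unlabeled test samples; no labels are used
        # theta_merged = theta_pre + sum_k lambda_k * tau_k
        merged = {n: p + sum(l * tv[n] for l, tv in zip(lambdas, task_vectors))
                  for n, p in theta_pre.items()}
        logits = functional_call(model, merged, (x,))
        logp = F.log_softmax(logits, dim=-1)
        entropy = -(logp.exp() * logp).sum(dim=-1).mean()  # Shannon entropy
        opt.zero_grad()
        entropy.backward()
        opt.step()
    return lambdas.detach()
```

Setting all coefficients to a fixed value (e.g., `init=0.3`) and skipping the optimization recovers plain task arithmetic; AdaMerging instead adapts them per task (or per layer) on the test distribution.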
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## Citation
15 | If you find our paper or this resource helpful, please consider citing our work:
16 | ```
17 | @article{AdaMerging_ICLR_2024,
18 | title={AdaMerging: Adaptive Model Merging for Multi-Task Learning},
19 | author={Yang, Enneng and Wang, Zhenyi and Shen, Li and Liu, Shiwei and Guo, Guibing and Wang, Xingwei and Tao, Dacheng},
20 | journal={The Twelfth International Conference on Learning Representations},
21 | year={2024}
22 | }
23 | ```
24 | Thanks!
25 |
26 |
27 | ## Datasets
28 | For dataset processing, refer to [task_vectors](https://github.com/mlfoundations/task_vectors).
29 |
30 | Or you can download the processed data from [Baidu Cloud disk](https://pan.baidu.com/s/1w0Z2UVv3NVmqDhjH8WTOJQ?pwd=kvg6).
31 |
32 |
33 | ## Checkpoints
34 |
35 | You can download the fine-tuned checkpoints from the [task_vectors#checkpoints](https://github.com/mlfoundations/task_vectors#checkpoints).
36 | The Google Drive folder is: [task_vectors_checkpoints](https://drive.google.com/drive/folders/1u_Tva6x0p6oxu5Eo0ZZsf-520Cc_3MKw)
37 |
38 | *Note: If ```torch.load(xxx_checkpoint).state_dict()``` fails, you can try ```pickle.load(open(xxx_checkpoint, 'rb')).state_dict()``` instead.*
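
A minimal version of that fallback (a sketch; `xxx_checkpoint` stands in for the path of a downloaded checkpoint):

```python
import pickle
import torch

try:
    state_dict = torch.load(xxx_checkpoint).state_dict()
except Exception:
    # some checkpoints only deserialize cleanly via pickle
    with open(xxx_checkpoint, 'rb') as f:
        state_dict = pickle.load(f).state_dict()
```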
39 |
40 | ## Code
41 | > [!NOTE]
42 | > We noticed that our [AdaMerging](https://github.com/tanganke/fusion_bench/tree/main/fusion_bench/method/adamerging) method has been implemented in [FusionBench](https://github.com/tanganke/fusion_bench), a model merging benchmark. Thanks to its excellent memory management, AdaMerging is highly efficient within this framework, with the entire training process taking **only a few minutes**.
43 |
44 | ### Train
45 |
46 | **If you want to train AdaMerging, run this part of the code. If you want to load the trained merging coefficients directly, refer to the Eval section.**
47 |
48 | First, enter the root directory of the source code:
49 | > cd root_path/src/
50 |
51 | Run Task Arithmetic [paper](https://arxiv.org/abs/2212.04089)
52 | > python main_task_arithmetic.py
53 |
54 | Run TIES-MERGING [paper](https://arxiv.org/abs/2306.01708)
55 | > python main_ties_merging.py
56 |
57 | Run Task-wise AdaMerging (Ours)
58 | > python main_task_wise_adamerging.py
59 |
60 | Run Task-wise AdaMerging++ (Ours)
61 | > python main_task_wise_adamergingpp.py
62 |
63 | Run Layer-wise AdaMerging (Ours)
64 | > python main_layer_wise_adamerging.py
65 |
66 | Run Layer-wise AdaMerging++ (Ours)
67 | > python main_layer_wise_adamergingpp.py
68 |
69 | *Note: Due to machine memory limitations, our implementation reloads the dataset at each step, which adds a significant amount of extra time. If your machine has enough memory, you can load all the data once before optimizing the merging coefficients, which speeds up training dramatically (i.e., the merging coefficients can be trained in a matter of minutes), as sketched below.*
70 |
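A sketch of that speed-up, assuming a list `dataset_names` and hypothetical `build_test_loader` / `adamerging_step` helpers; the point is simply to construct each dataloader once, outside the optimization loop:

```python
# build every test dataloader once, up front
loaders = {name: build_test_loader(name) for name in dataset_names}

for epoch in range(num_epochs):
    for name, loader in loaders.items():
        for batch in loader:
            adamerging_step(batch)  # one entropy-minimization update of the coefficients
```
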
71 | ### Eval
72 | Alternatively, you can load our trained merging coefficients, which can be found in the *[merging_cofficient.py](https://github.com/EnnengYang/AdaMerging/blob/main/src/merging_cofficient.py)* file. The general process is as follows:
73 |
74 | ```python
75 | # load the learned coefficients (assumes torch is already imported)
76 | from merging_cofficient import get_merging_cofficients
77 | ralpha = get_merging_cofficients(method_name, model_name)
78 | self.alpha = torch.Tensor(ralpha)
79 | 
80 | # wrap: merge the models, weighting each one's parameters by its coefficient
81 | if self.alpha.size()[0] == 1:  # task-wise merging: one coefficient per model
82 |     params = tuple(sum(pi * alphai for pi, alphai in zip(p, self.alpha[0].cpu())) for j, p in enumerate(zip(*self.paramslist)))
83 | else:  # layer-wise merging: one coefficient per model per layer
84 |     params = tuple(sum(pi * alphai for pi, alphai in zip(p, self.alpha[j].cpu())) for j, p in enumerate(zip(*self.paramslist)))
85 | ```
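
In this snippet, `self.paramslist` is assumed to hold one tuple of parameters per entry: the pretrained model's parameters first, followed by one task vector per task, so that `zip(*self.paramslist)` iterates over them layer by layer. This matches the coefficient tensors in the logs below, where the leading `1.0000` is the fixed weight of the pretrained model and the remaining entries are the learned weights of the eight task vectors.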
86 | More details can be found in this related repository: https://github.com/EnnengYang/RepresentationSurgery
87 |
88 |
89 |
90 | ## Acknowledgement
91 | Our implementation builds on the repositories below; many thanks to their authors.
92 |
93 | - Task Arithmetic: https://github.com/mlfoundations/task_vectors
94 |
95 | - TIES-MERGING: https://github.com/prateeky2806/ties-merging/tree/main
96 |
97 | - Model Soups: https://github.com/mlfoundations/model-soups
98 |
99 | - Tent: https://github.com/DequanWang/tent
100 |
101 |
102 |
103 |
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | Place checkpoints here.
2 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Place datasets here.
2 |
--------------------------------------------------------------------------------
/logs/ViT-B-32/log_20230921_202228_Task_wise_AdaMerging.txt:
--------------------------------------------------------------------------------
1 | [tensor([1.0000, 0.2901, 0.2902, 0.2907, 0.3009, 0.2926, 0.2923, 0.2902, 0.2904])]
2 | Eval: Epoch: 9 dataset: SUN397 ACC: 0.5622768642832001
3 | Eval: Epoch: 9 dataset: Cars ACC: 0.5585126228081084
4 | Eval: Epoch: 9 dataset: RESISC45 ACC: 0.6725396825396825
5 | Eval: Epoch: 9 dataset: EuroSAT ACC: 0.8048148148148148
6 | Eval: Epoch: 9 dataset: SVHN ACC: 0.7983635525507068
7 | Eval: Epoch: 9 dataset: GTSRB ACC: 0.6969912905779889
8 | Eval: Epoch: 9 dataset: MNIST ACC: 0.9713
9 | Eval: Epoch: 9 dataset: DTD ACC: 0.5079787234042553
10 | Eval: Epoch: 9 Avg ACC:0.6965971938723446
11 |
12 | [tensor([1.0000, 0.2805, 0.2808, 0.2830, 0.3047, 0.2870, 0.2844, 0.2812, 0.2810])]
13 | Eval: Epoch: 19 dataset: SUN397 ACC: 0.5705737416402675
14 | Eval: Epoch: 19 dataset: Cars ACC: 0.5647307548812337
15 | Eval: Epoch: 19 dataset: RESISC45 ACC: 0.6753968253968254
16 | Eval: Epoch: 19 dataset: EuroSAT ACC: 0.8218518518518518
17 | Eval: Epoch: 19 dataset: SVHN ACC: 0.795021511985249
18 | Eval: Epoch: 19 dataset: GTSRB ACC: 0.695249406175772
19 | Eval: Epoch: 19 dataset: MNIST ACC: 0.9687
20 | Eval: Epoch: 19 dataset: DTD ACC: 0.5122340425531915
21 | Eval: Epoch: 19 Avg ACC:0.7004697668105488
22 |
23 | [tensor([1.0000, 0.2716, 0.2722, 0.2760, 0.3090, 0.2827, 0.2787, 0.2729, 0.2721])]
24 | Eval: Epoch: 29 dataset: SUN397 ACC: 0.5773620958415046
25 | Eval: Epoch: 29 dataset: Cars ACC: 0.5698296231811963
26 | Eval: Epoch: 29 dataset: RESISC45 ACC: 0.6780952380952381
27 | Eval: Epoch: 29 dataset: EuroSAT ACC: 0.8344444444444444
28 | Eval: Epoch: 29 dataset: SVHN ACC: 0.7926782421634911
29 | Eval: Epoch: 29 dataset: GTSRB ACC: 0.6931908155186065
30 | Eval: Epoch: 29 dataset: MNIST ACC: 0.9668
31 | Eval: Epoch: 29 dataset: DTD ACC: 0.5132978723404256
32 | Eval: Epoch: 29 Avg ACC:0.7032122914481133
33 |
34 | [tensor([1.0000, 0.2634, 0.2642, 0.2689, 0.3120, 0.2806, 0.2740, 0.2655, 0.2637])]
35 | Eval: Epoch: 39 dataset: SUN397 ACC: 0.5817870970986071
36 | Eval: Epoch: 39 dataset: Cars ACC: 0.5730630518592215
37 | Eval: Epoch: 39 dataset: RESISC45 ACC: 0.6796825396825397
38 | Eval: Epoch: 39 dataset: EuroSAT ACC: 0.8466666666666667
39 | Eval: Epoch: 39 dataset: SVHN ACC: 0.7926782421634911
40 | Eval: Epoch: 39 dataset: GTSRB ACC: 0.6938242280285035
41 | Eval: Epoch: 39 dataset: MNIST ACC: 0.9648
42 | Eval: Epoch: 39 dataset: DTD ACC: 0.5138297872340426
43 | Eval: Epoch: 39 Avg ACC:0.705791451591634
44 |
45 | [tensor([1.0000, 0.2565, 0.2565, 0.2627, 0.3134, 0.2777, 0.2702, 0.2586, 0.2565])]
46 | Eval: Epoch: 49 dataset: SUN397 ACC: 0.5861618142505154
47 | Eval: Epoch: 49 dataset: Cars ACC: 0.575425942047009
48 | Eval: Epoch: 49 dataset: RESISC45 ACC: 0.6817460317460318
49 | Eval: Epoch: 49 dataset: EuroSAT ACC: 0.8555555555555555
50 | Eval: Epoch: 49 dataset: SVHN ACC: 0.7913721573448064
51 | Eval: Epoch: 49 dataset: GTSRB ACC: 0.6942992874109264
52 | Eval: Epoch: 49 dataset: MNIST ACC: 0.9618
53 | Eval: Epoch: 49 dataset: DTD ACC: 0.5117021276595745
54 | Eval: Epoch: 49 Avg ACC:0.7072578645018024
55 |
56 | [tensor([1.0000, 0.2499, 0.2495, 0.2588, 0.3129, 0.2748, 0.2689, 0.2522, 0.2500])]
57 | Eval: Epoch: 59 dataset: SUN397 ACC: 0.5901342585608689
58 | Eval: Epoch: 59 dataset: Cars ACC: 0.5767939311030966
59 | Eval: Epoch: 59 dataset: RESISC45 ACC: 0.6846031746031747
60 | Eval: Epoch: 59 dataset: EuroSAT ACC: 0.8596296296296296
61 | Eval: Epoch: 59 dataset: SVHN ACC: 0.7905654578979717
62 | Eval: Epoch: 59 dataset: GTSRB ACC: 0.6953285827395091
63 | Eval: Epoch: 59 dataset: MNIST ACC: 0.9596
64 | Eval: Epoch: 59 dataset: DTD ACC: 0.5122340425531915
65 | Eval: Epoch: 59 Avg ACC:0.7086111346359303
66 |
67 | [tensor([1.0000, 0.2444, 0.2435, 0.2557, 0.3117, 0.2748, 0.2691, 0.2476, 0.2439])]
68 | Eval: Epoch: 69 dataset: SUN397 ACC: 0.5918942022426711
69 | Eval: Epoch: 69 dataset: Cars ACC: 0.5786593707250342
70 | Eval: Epoch: 69 dataset: RESISC45 ACC: 0.6857142857142857
71 | Eval: Epoch: 69 dataset: EuroSAT ACC: 0.8611111111111112
72 | Eval: Epoch: 69 dataset: SVHN ACC: 0.7920251997541488
73 | Eval: Epoch: 69 dataset: GTSRB ACC: 0.6988915281076801
74 | Eval: Epoch: 69 dataset: MNIST ACC: 0.9584
75 | Eval: Epoch: 69 dataset: DTD ACC: 0.5090425531914894
76 | Eval: Epoch: 69 Avg ACC:0.7094672813558025
77 |
78 | [tensor([1.0000, 0.2391, 0.2377, 0.2537, 0.3122, 0.2764, 0.2725, 0.2450, 0.2385])]
79 | Eval: Epoch: 79 dataset: SUN397 ACC: 0.5919947704530598
80 | Eval: Epoch: 79 dataset: Cars ACC: 0.5775401069518716
81 | Eval: Epoch: 79 dataset: RESISC45 ACC: 0.6850793650793651
82 | Eval: Epoch: 79 dataset: EuroSAT ACC: 0.8625925925925926
83 | Eval: Epoch: 79 dataset: SVHN ACC: 0.7949446834665027
84 | Eval: Epoch: 79 dataset: GTSRB ACC: 0.705938242280285
85 | Eval: Epoch: 79 dataset: MNIST ACC: 0.9577
86 | Eval: Epoch: 79 dataset: DTD ACC: 0.5074468085106383
87 | Eval: Epoch: 79 Avg ACC:0.7104045711667895
88 |
89 | [tensor([1.0000, 0.2350, 0.2321, 0.2503, 0.3123, 0.2803, 0.2765, 0.2433, 0.2336])]
90 | Eval: Epoch: 89 dataset: SUN397 ACC: 0.5916930658218937
91 | Eval: Epoch: 89 dataset: Cars ACC: 0.5760477552543216
92 | Eval: Epoch: 89 dataset: RESISC45 ACC: 0.6823809523809524
93 | Eval: Epoch: 89 dataset: EuroSAT ACC: 0.8618518518518519
94 | Eval: Epoch: 89 dataset: SVHN ACC: 0.7998617086662569
95 | Eval: Epoch: 89 dataset: GTSRB ACC: 0.7125890736342043
96 | Eval: Epoch: 89 dataset: MNIST ACC: 0.9572
97 | Eval: Epoch: 89 dataset: DTD ACC: 0.5031914893617021
98 | Eval: Epoch: 89 Avg ACC:0.7106019871213979
99 |
100 | [tensor([1.0000, 0.2311, 0.2270, 0.2478, 0.3134, 0.2835, 0.2806, 0.2420, 0.2294])]
101 | Eval: Epoch: 99 dataset: SUN397 ACC: 0.5907376678232011
102 | Eval: Epoch: 99 dataset: Cars ACC: 0.5733117771421465
103 | Eval: Epoch: 99 dataset: RESISC45 ACC: 0.680952380952381
104 | Eval: Epoch: 99 dataset: EuroSAT ACC: 0.8607407407407407
105 | Eval: Epoch: 99 dataset: SVHN ACC: 0.8045482483097726
106 | Eval: Epoch: 99 dataset: GTSRB ACC: 0.7191607284243864
107 | Eval: Epoch: 99 dataset: MNIST ACC: 0.9567
108 | Eval: Epoch: 99 dataset: DTD ACC: 0.5015957446808511
109 | Eval: Epoch: 99 Avg ACC:0.7109684110091848
110 |
111 | [tensor([1.0000, 0.2279, 0.2223, 0.2460, 0.3147, 0.2856, 0.2859, 0.2405, 0.2260])]
112 | Eval: Epoch: 109 dataset: SUN397 ACC: 0.590285110876452
113 | Eval: Epoch: 109 dataset: Cars ACC: 0.5715707001616714
114 | Eval: Epoch: 109 dataset: RESISC45 ACC: 0.6785714285714286
115 | Eval: Epoch: 109 dataset: EuroSAT ACC: 0.8618518518518519
116 | Eval: Epoch: 109 dataset: SVHN ACC: 0.8073140749846343
117 | Eval: Epoch: 109 dataset: GTSRB ACC: 0.7269200316706255
118 | Eval: Epoch: 109 dataset: MNIST ACC: 0.9556
119 | Eval: Epoch: 109 dataset: DTD ACC: 0.5
120 | Eval: Epoch: 109 Avg ACC:0.7115141497645829
121 |
122 | [tensor([1.0000, 0.2259, 0.2190, 0.2443, 0.3152, 0.2878, 0.2920, 0.2400, 0.2241])]
123 | Eval: Epoch: 119 dataset: SUN397 ACC: 0.5889274400362046
124 | Eval: Epoch: 119 dataset: Cars ACC: 0.5694565352568088
125 | Eval: Epoch: 119 dataset: RESISC45 ACC: 0.6771428571428572
126 | Eval: Epoch: 119 dataset: EuroSAT ACC: 0.8607407407407407
127 | Eval: Epoch: 119 dataset: SVHN ACC: 0.8103103872157345
128 | Eval: Epoch: 119 dataset: GTSRB ACC: 0.7344418052256532
129 | Eval: Epoch: 119 dataset: MNIST ACC: 0.9552
130 | Eval: Epoch: 119 dataset: DTD ACC: 0.49627659574468086
131 | Eval: Epoch: 119 Avg ACC:0.711562045170335
132 |
133 | [tensor([1.0000, 0.2237, 0.2146, 0.2460, 0.3189, 0.2891, 0.2978, 0.2375, 0.2221])]
134 | Eval: Epoch: 129 dataset: SUN397 ACC: 0.5877206215115403
135 | Eval: Epoch: 129 dataset: Cars ACC: 0.5668449197860963
136 | Eval: Epoch: 129 dataset: RESISC45 ACC: 0.6757142857142857
137 | Eval: Epoch: 129 dataset: EuroSAT ACC: 0.8633333333333333
138 | Eval: Epoch: 129 dataset: SVHN ACC: 0.8115780577750461
139 | Eval: Epoch: 129 dataset: GTSRB ACC: 0.741409342834521
140 | Eval: Epoch: 129 dataset: MNIST ACC: 0.9541
141 | Eval: Epoch: 129 dataset: DTD ACC: 0.4946808510638298
142 | Eval: Epoch: 129 Avg ACC:0.7119226765023314
143 |
144 | [tensor([1.0000, 0.2223, 0.2108, 0.2463, 0.3256, 0.2921, 0.3023, 0.2342, 0.2199])]
145 | Eval: Epoch: 139 dataset: SUN397 ACC: 0.5865138029868758
146 | Eval: Epoch: 139 dataset: Cars ACC: 0.5651038428056212
147 | Eval: Epoch: 139 dataset: RESISC45 ACC: 0.6726984126984127
148 | Eval: Epoch: 139 dataset: EuroSAT ACC: 0.8659259259259259
149 | Eval: Epoch: 139 dataset: SVHN ACC: 0.8131530424093424
150 | Eval: Epoch: 139 dataset: GTSRB ACC: 0.7454473475851148
151 | Eval: Epoch: 139 dataset: MNIST ACC: 0.9517
152 | Eval: Epoch: 139 dataset: DTD ACC: 0.49148936170212765
153 | Eval: Epoch: 139 Avg ACC:0.7115039670141775
154 |
155 | [tensor([1.0000, 0.2210, 0.2071, 0.2469, 0.3293, 0.2918, 0.3060, 0.2291, 0.2173])]
156 | Eval: Epoch: 149 dataset: SUN397 ACC: 0.5866143711972646
157 | Eval: Epoch: 149 dataset: Cars ACC: 0.5639845790324587
158 | Eval: Epoch: 149 dataset: RESISC45 ACC: 0.6733333333333333
159 | Eval: Epoch: 149 dataset: EuroSAT ACC: 0.8685185185185185
160 | Eval: Epoch: 149 dataset: SVHN ACC: 0.8123079287031346
161 | Eval: Epoch: 149 dataset: GTSRB ACC: 0.7513064133016627
162 | Eval: Epoch: 149 dataset: MNIST ACC: 0.9493
163 | Eval: Epoch: 149 dataset: DTD ACC: 0.48829787234042554
164 | Eval: Epoch: 149 Avg ACC:0.7117078770533498
165 |
166 | [tensor([1.0000, 0.2190, 0.2033, 0.2497, 0.3315, 0.2897, 0.3085, 0.2263, 0.2147])]
167 | Eval: Epoch: 159 dataset: SUN397 ACC: 0.5871674963544024
168 | Eval: Epoch: 159 dataset: Cars ACC: 0.5637358537495336
169 | Eval: Epoch: 159 dataset: RESISC45 ACC: 0.6763492063492064
170 | Eval: Epoch: 159 dataset: EuroSAT ACC: 0.8729629629629629
171 | Eval: Epoch: 159 dataset: SVHN ACC: 0.8099262446220037
172 | Eval: Epoch: 159 dataset: GTSRB ACC: 0.7548693586698337
173 | Eval: Epoch: 159 dataset: MNIST ACC: 0.9475
174 | Eval: Epoch: 159 dataset: DTD ACC: 0.48829787234042554
175 | Eval: Epoch: 159 Avg ACC:0.7126011243810461
176 |
177 | [tensor([1.0000, 0.2168, 0.2003, 0.2509, 0.3316, 0.2899, 0.3136, 0.2241, 0.2125])]
178 | Eval: Epoch: 169 dataset: SUN397 ACC: 0.5857092573037663
179 | Eval: Epoch: 169 dataset: Cars ACC: 0.562119139410521
180 | Eval: Epoch: 169 dataset: RESISC45 ACC: 0.6768253968253968
181 | Eval: Epoch: 169 dataset: EuroSAT ACC: 0.8725925925925926
182 | Eval: Epoch: 169 dataset: SVHN ACC: 0.8104256299938537
183 | Eval: Epoch: 169 dataset: GTSRB ACC: 0.7622327790973872
184 | Eval: Epoch: 169 dataset: MNIST ACC: 0.9461
185 | Eval: Epoch: 169 dataset: DTD ACC: 0.48723404255319147
186 | Eval: Epoch: 169 Avg ACC:0.7129048547220886
187 |
188 | [tensor([1.0000, 0.2165, 0.1982, 0.2516, 0.3297, 0.2904, 0.3178, 0.2235, 0.2115])]
189 | Eval: Epoch: 179 dataset: SUN397 ACC: 0.5854075526726001
190 | Eval: Epoch: 179 dataset: Cars ACC: 0.5603780624300461
191 | Eval: Epoch: 179 dataset: RESISC45 ACC: 0.6765079365079365
192 | Eval: Epoch: 179 dataset: EuroSAT ACC: 0.8696296296296296
193 | Eval: Epoch: 179 dataset: SVHN ACC: 0.8113475722188076
194 | Eval: Epoch: 179 dataset: GTSRB ACC: 0.7672209026128266
195 | Eval: Epoch: 179 dataset: MNIST ACC: 0.9458
196 | Eval: Epoch: 179 dataset: DTD ACC: 0.48723404255319147
197 | Eval: Epoch: 179 Avg ACC:0.7129407123281297
198 |
199 | [tensor([1.0000, 0.2144, 0.1972, 0.2519, 0.3271, 0.2915, 0.3227, 0.2221, 0.2098])]
200 | Eval: Epoch: 189 dataset: SUN397 ACC: 0.5843515864635189
201 | Eval: Epoch: 189 dataset: Cars ACC: 0.5592587986568834
202 | Eval: Epoch: 189 dataset: RESISC45 ACC: 0.676031746031746
203 | Eval: Epoch: 189 dataset: EuroSAT ACC: 0.8662962962962963
204 | Eval: Epoch: 189 dataset: SVHN ACC: 0.81399815611555
205 | Eval: Epoch: 189 dataset: GTSRB ACC: 0.7746634996041172
206 | Eval: Epoch: 189 dataset: MNIST ACC: 0.945
207 | Eval: Epoch: 189 dataset: DTD ACC: 0.48723404255319147
208 | Eval: Epoch: 189 Avg ACC:0.7133542657151629
209 |
210 | [tensor([1.0000, 0.2125, 0.1959, 0.2532, 0.3272, 0.2929, 0.3287, 0.2200, 0.2088])]
211 | Eval: Epoch: 199 dataset: SUN397 ACC: 0.5827927792024941
212 | Eval: Epoch: 199 dataset: Cars ACC: 0.5572689963934834
213 | Eval: Epoch: 199 dataset: RESISC45 ACC: 0.6765079365079365
214 | Eval: Epoch: 199 dataset: EuroSAT ACC: 0.8659259259259259
215 | Eval: Epoch: 199 dataset: SVHN ACC: 0.8161493546404426
216 | Eval: Epoch: 199 dataset: GTSRB ACC: 0.780443388756928
217 | Eval: Epoch: 199 dataset: MNIST ACC: 0.9435
218 | Eval: Epoch: 199 dataset: DTD ACC: 0.4856382978723404
219 | Eval: Epoch: 199 Avg ACC:0.7135283349124439
220 |
221 | [tensor([1.0000, 0.2113, 0.1929, 0.2542, 0.3276, 0.2950, 0.3328, 0.2176, 0.2084])]
222 | Eval: Epoch: 209 dataset: SUN397 ACC: 0.5821390858349675
223 | Eval: Epoch: 209 dataset: Cars ACC: 0.5565228205447084
224 | Eval: Epoch: 209 dataset: RESISC45 ACC: 0.676031746031746
225 | Eval: Epoch: 209 dataset: EuroSAT ACC: 0.8651851851851852
226 | Eval: Epoch: 209 dataset: SVHN ACC: 0.8174938537185003
227 | Eval: Epoch: 209 dataset: GTSRB ACC: 0.7844022169437846
228 | Eval: Epoch: 209 dataset: MNIST ACC: 0.942
229 | Eval: Epoch: 209 dataset: DTD ACC: 0.4851063829787234
230 | Eval: Epoch: 209 Avg ACC:0.7136101614047019
231 |
232 | [tensor([1.0000, 0.2099, 0.1892, 0.2544, 0.3261, 0.2960, 0.3377, 0.2155, 0.2074])]
233 | Eval: Epoch: 219 dataset: SUN397 ACC: 0.5813848242570523
234 | Eval: Epoch: 219 dataset: Cars ACC: 0.5551548314886208
235 | Eval: Epoch: 219 dataset: RESISC45 ACC: 0.6768253968253968
236 | Eval: Epoch: 219 dataset: EuroSAT ACC: 0.8622222222222222
237 | Eval: Epoch: 219 dataset: SVHN ACC: 0.8196834665027658
238 | Eval: Epoch: 219 dataset: GTSRB ACC: 0.789944576405384
239 | Eval: Epoch: 219 dataset: MNIST ACC: 0.941
240 | Eval: Epoch: 219 dataset: DTD ACC: 0.48351063829787233
241 | Eval: Epoch: 219 Avg ACC:0.7137157444999143
242 |
243 | [tensor([1.0000, 0.2098, 0.1874, 0.2561, 0.3247, 0.2958, 0.3429, 0.2145, 0.2068])]
244 | Eval: Epoch: 229 dataset: SUN397 ACC: 0.5807814149947201
245 | Eval: Epoch: 229 dataset: Cars ACC: 0.5539112050739958
246 | Eval: Epoch: 229 dataset: RESISC45 ACC: 0.677936507936508
247 | Eval: Epoch: 229 dataset: EuroSAT ACC: 0.8603703703703703
248 | Eval: Epoch: 229 dataset: SVHN ACC: 0.8201444376152428
249 | Eval: Epoch: 229 dataset: GTSRB ACC: 0.795882818685669
250 | Eval: Epoch: 229 dataset: MNIST ACC: 0.9403
251 | Eval: Epoch: 229 dataset: DTD ACC: 0.4803191489361702
252 | Eval: Epoch: 229 Avg ACC:0.7137057379515845
253 |
254 | [tensor([1.0000, 0.2103, 0.1853, 0.2589, 0.3245, 0.2945, 0.3494, 0.2131, 0.2062])]
255 | Eval: Epoch: 239 dataset: SUN397 ACC: 0.5794740282596671
256 | Eval: Epoch: 239 dataset: Cars ACC: 0.5519214028105957
257 | Eval: Epoch: 239 dataset: RESISC45 ACC: 0.6788888888888889
258 | Eval: Epoch: 239 dataset: EuroSAT ACC: 0.8596296296296296
259 | Eval: Epoch: 239 dataset: SVHN ACC: 0.819299323909035
260 | Eval: Epoch: 239 dataset: GTSRB ACC: 0.801187648456057
261 | Eval: Epoch: 239 dataset: MNIST ACC: 0.9383
262 | Eval: Epoch: 239 dataset: DTD ACC: 0.4776595744680851
263 | Eval: Epoch: 239 Avg ACC:0.7132950620527447
264 |
265 | [tensor([1.0000, 0.2109, 0.1836, 0.2605, 0.3229, 0.2918, 0.3555, 0.2109, 0.2054])]
266 | Eval: Epoch: 249 dataset: SUN397 ACC: 0.5795243123648615
267 | Eval: Epoch: 249 dataset: Cars ACC: 0.5511752269618206
268 | Eval: Epoch: 249 dataset: RESISC45 ACC: 0.6803174603174603
269 | Eval: Epoch: 249 dataset: EuroSAT ACC: 0.8588888888888889
270 | Eval: Epoch: 249 dataset: SVHN ACC: 0.8172249539028887
271 | Eval: Epoch: 249 dataset: GTSRB ACC: 0.8079176563737134
272 | Eval: Epoch: 249 dataset: MNIST ACC: 0.9366
273 | Eval: Epoch: 249 dataset: DTD ACC: 0.47606382978723405
274 | Eval: Epoch: 249 Avg ACC:0.7134640410746085
275 |
276 | [tensor([1.0000, 0.2100, 0.1804, 0.2611, 0.3224, 0.2915, 0.3598, 0.2104, 0.2035])]
277 | Eval: Epoch: 259 dataset: SUN397 ACC: 0.5788203348921406
278 | Eval: Epoch: 259 dataset: Cars ACC: 0.5504290511130456
279 | Eval: Epoch: 259 dataset: RESISC45 ACC: 0.6807936507936508
280 | Eval: Epoch: 259 dataset: EuroSAT ACC: 0.8577777777777778
281 | Eval: Epoch: 259 dataset: SVHN ACC: 0.817839582052858
282 | Eval: Epoch: 259 dataset: GTSRB ACC: 0.8131433095803642
283 | Eval: Epoch: 259 dataset: MNIST ACC: 0.9363
284 | Eval: Epoch: 259 dataset: DTD ACC: 0.474468085106383
285 | Eval: Epoch: 259 Avg ACC:0.7136964739145275
286 |
287 | [tensor([1.0000, 0.2102, 0.1765, 0.2594, 0.3206, 0.2938, 0.3637, 0.2105, 0.2014])]
288 | Eval: Epoch: 269 dataset: SUN397 ACC: 0.5783677779453915
289 | Eval: Epoch: 269 dataset: Cars ACC: 0.5488123367740331
290 | Eval: Epoch: 269 dataset: RESISC45 ACC: 0.6795238095238095
291 | Eval: Epoch: 269 dataset: EuroSAT ACC: 0.8522222222222222
292 | Eval: Epoch: 269 dataset: SVHN ACC: 0.821220036877689
293 | Eval: Epoch: 269 dataset: GTSRB ACC: 0.8174980205859066
294 | Eval: Epoch: 269 dataset: MNIST ACC: 0.9363
295 | Eval: Epoch: 269 dataset: DTD ACC: 0.4728723404255319
296 | Eval: Epoch: 269 Avg ACC:0.713352068044323
297 |
298 | [tensor([1.0000, 0.2110, 0.1739, 0.2590, 0.3193, 0.2976, 0.3662, 0.2108, 0.2005])]
299 | Eval: Epoch: 279 dataset: SUN397 ACC: 0.5775632322622819
300 | Eval: Epoch: 279 dataset: Cars ACC: 0.546325083944783
301 | Eval: Epoch: 279 dataset: RESISC45 ACC: 0.6776190476190476
302 | Eval: Epoch: 279 dataset: EuroSAT ACC: 0.8466666666666667
303 | Eval: Epoch: 279 dataset: SVHN ACC: 0.8248309772587584
304 | Eval: Epoch: 279 dataset: GTSRB ACC: 0.8192399049881235
305 | Eval: Epoch: 279 dataset: MNIST ACC: 0.9366
306 | Eval: Epoch: 279 dataset: DTD ACC: 0.47074468085106386
307 | Eval: Epoch: 279 Avg ACC:0.7124486991988407
308 |
309 | [tensor([1.0000, 0.2103, 0.1713, 0.2608, 0.3216, 0.2974, 0.3687, 0.2081, 0.1986])]
310 | Eval: Epoch: 289 dataset: SUN397 ACC: 0.5773118117363102
311 | Eval: Epoch: 289 dataset: Cars ACC: 0.545330182813083
312 | Eval: Epoch: 289 dataset: RESISC45 ACC: 0.6792063492063493
313 | Eval: Epoch: 289 dataset: EuroSAT ACC: 0.85
314 | Eval: Epoch: 289 dataset: SVHN ACC: 0.8242547633681623
315 | Eval: Epoch: 289 dataset: GTSRB ACC: 0.8209817893903405
316 | Eval: Epoch: 289 dataset: MNIST ACC: 0.9344
317 | Eval: Epoch: 289 dataset: DTD ACC: 0.46702127659574466
318 | Eval: Epoch: 289 Avg ACC:0.7123132716387488
319 |
320 | [tensor([1.0000, 0.2105, 0.1686, 0.2638, 0.3234, 0.2978, 0.3720, 0.2070, 0.1964])]
321 | Eval: Epoch: 299 dataset: SUN397 ACC: 0.5765072660532006
322 | Eval: Epoch: 299 dataset: Cars ACC: 0.5432160179082204
323 | Eval: Epoch: 299 dataset: RESISC45 ACC: 0.6795238095238095
324 | Eval: Epoch: 299 dataset: EuroSAT ACC: 0.8514814814814815
325 | Eval: Epoch: 299 dataset: SVHN ACC: 0.8242931776275353
326 | Eval: Epoch: 299 dataset: GTSRB ACC: 0.8230403800475059
327 | Eval: Epoch: 299 dataset: MNIST ACC: 0.9336
328 | Eval: Epoch: 299 dataset: DTD ACC: 0.4664893617021277
329 | Eval: Epoch: 299 Avg ACC:0.7122689367929852
330 |
331 | [tensor([1.0000, 0.2094, 0.1661, 0.2631, 0.3229, 0.2994, 0.3772, 0.2078, 0.1944])]
332 | Eval: Epoch: 309 dataset: SUN397 ACC: 0.5753507316337306
333 | Eval: Epoch: 309 dataset: Cars ACC: 0.5414749409277453
334 | Eval: Epoch: 309 dataset: RESISC45 ACC: 0.677936507936508
335 | Eval: Epoch: 309 dataset: EuroSAT ACC: 0.847037037037037
336 | Eval: Epoch: 309 dataset: SVHN ACC: 0.8262907191149355
337 | Eval: Epoch: 309 dataset: GTSRB ACC: 0.8274742676167854
338 | Eval: Epoch: 309 dataset: MNIST ACC: 0.9337
339 | Eval: Epoch: 309 dataset: DTD ACC: 0.46117021276595743
340 | Eval: Epoch: 309 Avg ACC:0.7113043021290875
341 |
342 | [tensor([1.0000, 0.2081, 0.1654, 0.2625, 0.3215, 0.3005, 0.3838, 0.2073, 0.1946])]
343 | Eval: Epoch: 319 dataset: SUN397 ACC: 0.5737416402675114
344 | Eval: Epoch: 319 dataset: Cars ACC: 0.5392364133814203
345 | Eval: Epoch: 319 dataset: RESISC45 ACC: 0.6758730158730158
346 | Eval: Epoch: 319 dataset: EuroSAT ACC: 0.8425925925925926
347 | Eval: Epoch: 319 dataset: SVHN ACC: 0.8275583896742471
348 | Eval: Epoch: 319 dataset: GTSRB ACC: 0.8321456848772764
349 | Eval: Epoch: 319 dataset: MNIST ACC: 0.933
350 | Eval: Epoch: 319 dataset: DTD ACC: 0.4622340425531915
351 | Eval: Epoch: 319 Avg ACC:0.7107977224024069
352 |
353 | [tensor([1.0000, 0.2067, 0.1635, 0.2626, 0.3215, 0.3016, 0.3885, 0.2079, 0.1931])]
354 | Eval: Epoch: 329 dataset: SUN397 ACC: 0.5716799919545432
355 | Eval: Epoch: 329 dataset: Cars ACC: 0.5381171496082576
356 | Eval: Epoch: 329 dataset: RESISC45 ACC: 0.6749206349206349
357 | Eval: Epoch: 329 dataset: EuroSAT ACC: 0.84
358 | Eval: Epoch: 329 dataset: SVHN ACC: 0.8284035033804549
359 | Eval: Epoch: 329 dataset: GTSRB ACC: 0.835550277117973
360 | Eval: Epoch: 329 dataset: MNIST ACC: 0.9332
361 | Eval: Epoch: 329 dataset: DTD ACC: 0.4622340425531915
362 | Eval: Epoch: 329 Avg ACC:0.710513199941882
363 |
364 | [tensor([1.0000, 0.2057, 0.1604, 0.2641, 0.3229, 0.2967, 0.3910, 0.2061, 0.1897])]
365 | Eval: Epoch: 339 dataset: SUN397 ACC: 0.572383969427264
366 | Eval: Epoch: 339 dataset: Cars ACC: 0.5376196990424077
367 | Eval: Epoch: 339 dataset: RESISC45 ACC: 0.6765079365079365
368 | Eval: Epoch: 339 dataset: EuroSAT ACC: 0.8433333333333334
369 | Eval: Epoch: 339 dataset: SVHN ACC: 0.8253687768899816
370 | Eval: Epoch: 339 dataset: GTSRB ACC: 0.8382422802850357
371 | Eval: Epoch: 339 dataset: MNIST ACC: 0.9313
372 | Eval: Epoch: 339 dataset: DTD ACC: 0.45797872340425533
373 | Eval: Epoch: 339 Avg ACC:0.7103418398612767
374 |
375 | [tensor([1.0000, 0.2054, 0.1573, 0.2648, 0.3239, 0.2931, 0.3932, 0.2041, 0.1868])]
376 | Eval: Epoch: 349 dataset: SUN397 ACC: 0.573540503846734
377 | Eval: Epoch: 349 dataset: Cars ACC: 0.5381171496082576
378 | Eval: Epoch: 349 dataset: RESISC45 ACC: 0.6773015873015873
379 | Eval: Epoch: 349 dataset: EuroSAT ACC: 0.8455555555555555
380 | Eval: Epoch: 349 dataset: SVHN ACC: 0.822180393362016
381 | Eval: Epoch: 349 dataset: GTSRB ACC: 0.8398258115597783
382 | Eval: Epoch: 349 dataset: MNIST ACC: 0.9298
383 | Eval: Epoch: 349 dataset: DTD ACC: 0.45478723404255317
384 | Eval: Epoch: 349 Avg ACC:0.7101385294095602
385 |
386 | [tensor([1.0000, 0.2083, 0.1568, 0.2709, 0.3245, 0.2906, 0.3951, 0.2031, 0.1869])]
387 | Eval: Epoch: 359 dataset: SUN397 ACC: 0.5740433448986775
388 | Eval: Epoch: 359 dataset: Cars ACC: 0.5373709737594826
389 | Eval: Epoch: 359 dataset: RESISC45 ACC: 0.68
390 | Eval: Epoch: 359 dataset: EuroSAT ACC: 0.847037037037037
391 | Eval: Epoch: 359 dataset: SVHN ACC: 0.8187231100184389
392 | Eval: Epoch: 359 dataset: GTSRB ACC: 0.8408551068883611
393 | Eval: Epoch: 359 dataset: MNIST ACC: 0.9289
394 | Eval: Epoch: 359 dataset: DTD ACC: 0.4531914893617021
395 | Eval: Epoch: 359 Avg ACC:0.7100151327454624
396 |
397 | [tensor([1.0000, 0.2097, 0.1568, 0.2737, 0.3251, 0.2885, 0.3968, 0.2038, 0.1862])]
398 | Eval: Epoch: 369 dataset: SUN397 ACC: 0.5736410720571228
399 | Eval: Epoch: 369 dataset: Cars ACC: 0.5367491605521701
400 | Eval: Epoch: 369 dataset: RESISC45 ACC: 0.6804761904761905
401 | Eval: Epoch: 369 dataset: EuroSAT ACC: 0.8466666666666667
402 | Eval: Epoch: 369 dataset: SVHN ACC: 0.816379840196681
403 | Eval: Epoch: 369 dataset: GTSRB ACC: 0.841409342834521
404 | Eval: Epoch: 369 dataset: MNIST ACC: 0.9287
405 | Eval: Epoch: 369 dataset: DTD ACC: 0.4531914893617021
406 | Eval: Epoch: 369 Avg ACC:0.7096517202681317
407 |
408 | [tensor([1.0000, 0.2091, 0.1553, 0.2723, 0.3230, 0.2890, 0.4026, 0.2043, 0.1841])]
409 | Eval: Epoch: 379 dataset: SUN397 ACC: 0.5726856740584302
410 | Eval: Epoch: 379 dataset: Cars ACC: 0.5352568088546201
411 | Eval: Epoch: 379 dataset: RESISC45 ACC: 0.6801587301587302
412 | Eval: Epoch: 379 dataset: EuroSAT ACC: 0.8433333333333334
413 | Eval: Epoch: 379 dataset: SVHN ACC: 0.8182237246465888
414 | Eval: Epoch: 379 dataset: GTSRB ACC: 0.8460807600950119
415 | Eval: Epoch: 379 dataset: MNIST ACC: 0.9286
416 | Eval: Epoch: 379 dataset: DTD ACC: 0.4526595744680851
417 | Eval: Epoch: 379 Avg ACC:0.70962482570185
418 |
419 | [tensor([1.0000, 0.2077, 0.1540, 0.2692, 0.3209, 0.2916, 0.4059, 0.2037, 0.1824])]
420 | Eval: Epoch: 389 dataset: SUN397 ACC: 0.5717805601649318
421 | Eval: Epoch: 389 dataset: Cars ACC: 0.5346349956473075
422 | Eval: Epoch: 389 dataset: RESISC45 ACC: 0.6784126984126985
423 | Eval: Epoch: 389 dataset: EuroSAT ACC: 0.8392592592592593
424 | Eval: Epoch: 389 dataset: SVHN ACC: 0.8219883220651506
425 | Eval: Epoch: 389 dataset: GTSRB ACC: 0.8486935866983373
426 | Eval: Epoch: 389 dataset: MNIST ACC: 0.9286
427 | Eval: Epoch: 389 dataset: DTD ACC: 0.45159574468085106
428 | Eval: Epoch: 389 Avg ACC:0.7093706458660671
429 |
430 | [tensor([1.0000, 0.2075, 0.1515, 0.2692, 0.3206, 0.2927, 0.4076, 0.2034, 0.1795])]
431 | Eval: Epoch: 399 dataset: SUN397 ACC: 0.5719816965857093
432 | Eval: Epoch: 399 dataset: Cars ACC: 0.533018281308295
433 | Eval: Epoch: 399 dataset: RESISC45 ACC: 0.6785714285714286
434 | Eval: Epoch: 399 dataset: EuroSAT ACC: 0.8381481481481482
435 | Eval: Epoch: 399 dataset: SVHN ACC: 0.8229102642901045
436 | Eval: Epoch: 399 dataset: GTSRB ACC: 0.8501187648456057
437 | Eval: Epoch: 399 dataset: MNIST ACC: 0.9285
438 | Eval: Epoch: 399 dataset: DTD ACC: 0.451063829787234
439 | Eval: Epoch: 399 Avg ACC:0.7092890516920656
440 |
441 | [tensor([1.0000, 0.2086, 0.1492, 0.2702, 0.3197, 0.2921, 0.4047, 0.2030, 0.1794])]
442 | Eval: Epoch: 409 dataset: SUN397 ACC: 0.5729873786895963
443 | Eval: Epoch: 409 dataset: Cars ACC: 0.5331426439497575
444 | Eval: Epoch: 409 dataset: RESISC45 ACC: 0.6798412698412698
445 | Eval: Epoch: 409 dataset: EuroSAT ACC: 0.8381481481481482
446 | Eval: Epoch: 409 dataset: SVHN ACC: 0.8227181929932391
447 | Eval: Epoch: 409 dataset: GTSRB ACC: 0.8481393507521774
448 | Eval: Epoch: 409 dataset: MNIST ACC: 0.9286
449 | Eval: Epoch: 409 dataset: DTD ACC: 0.451063829787234
450 | Eval: Epoch: 409 Avg ACC:0.7093301017701779
451 |
452 | [tensor([1.0000, 0.2098, 0.1449, 0.2720, 0.3202, 0.2917, 0.4005, 0.2022, 0.1792])]
453 | Eval: Epoch: 419 dataset: SUN397 ACC: 0.5760044250012571
454 | Eval: Epoch: 419 dataset: Cars ACC: 0.533018281308295
455 | Eval: Epoch: 419 dataset: RESISC45 ACC: 0.6819047619047619
456 | Eval: Epoch: 419 dataset: EuroSAT ACC: 0.8414814814814815
457 | Eval: Epoch: 419 dataset: SVHN ACC: 0.8224108789182545
458 | Eval: Epoch: 419 dataset: GTSRB ACC: 0.8456848772763262
459 | Eval: Epoch: 419 dataset: MNIST ACC: 0.9283
460 | Eval: Epoch: 419 dataset: DTD ACC: 0.4521276595744681
461 | Eval: Epoch: 419 Avg ACC:0.7101165456831056
462 |
463 | [tensor([1.0000, 0.2116, 0.1421, 0.2741, 0.3199, 0.2931, 0.4009, 0.2016, 0.1778])]
464 | Eval: Epoch: 429 dataset: SUN397 ACC: 0.5763061296324232
465 | Eval: Epoch: 429 dataset: Cars ACC: 0.5317746548936699
466 | Eval: Epoch: 429 dataset: RESISC45 ACC: 0.6823809523809524
467 | Eval: Epoch: 429 dataset: EuroSAT ACC: 0.8414814814814815
468 | Eval: Epoch: 429 dataset: SVHN ACC: 0.8236017209588199
469 | Eval: Epoch: 429 dataset: GTSRB ACC: 0.8456848772763262
470 | Eval: Epoch: 429 dataset: MNIST ACC: 0.9283
471 | Eval: Epoch: 429 dataset: DTD ACC: 0.4521276595744681
472 | Eval: Epoch: 429 Avg ACC:0.7102071845247677
473 |
474 | [tensor([1.0000, 0.2136, 0.1415, 0.2752, 0.3196, 0.2965, 0.4043, 0.2038, 0.1771])]
475 | Eval: Epoch: 439 dataset: SUN397 ACC: 0.5743953336350379
476 | Eval: Epoch: 439 dataset: Cars ACC: 0.5287899514985698
477 | Eval: Epoch: 439 dataset: RESISC45 ACC: 0.6815873015873016
478 | Eval: Epoch: 439 dataset: EuroSAT ACC: 0.8351851851851851
479 | Eval: Epoch: 439 dataset: SVHN ACC: 0.8256760909649662
480 | Eval: Epoch: 439 dataset: GTSRB ACC: 0.8468725257323833
481 | Eval: Epoch: 439 dataset: MNIST ACC: 0.9296
482 | Eval: Epoch: 439 dataset: DTD ACC: 0.45053191489361705
483 | Eval: Epoch: 439 Avg ACC:0.7090797879371326
484 |
485 | [tensor([1.0000, 0.2134, 0.1402, 0.2763, 0.3215, 0.2935, 0.4036, 0.2002, 0.1741])]
486 | Eval: Epoch: 449 dataset: SUN397 ACC: 0.5767084024739779
487 | Eval: Epoch: 449 dataset: Cars ACC: 0.5306553911205074
488 | Eval: Epoch: 449 dataset: RESISC45 ACC: 0.6826984126984127
489 | Eval: Epoch: 449 dataset: EuroSAT ACC: 0.8422222222222222
490 | Eval: Epoch: 449 dataset: SVHN ACC: 0.8232944068838353
491 | Eval: Epoch: 449 dataset: GTSRB ACC: 0.8471100554235946
492 | Eval: Epoch: 449 dataset: MNIST ACC: 0.9272
493 | Eval: Epoch: 449 dataset: DTD ACC: 0.44840425531914896
494 | Eval: Epoch: 449 Avg ACC:0.7097866432677125
495 |
496 | [tensor([1.0000, 0.2137, 0.1394, 0.2744, 0.3232, 0.2922, 0.4023, 0.1989, 0.1706])]
497 | Eval: Epoch: 459 dataset: SUN397 ACC: 0.5781666415246141
498 | Eval: Epoch: 459 dataset: Cars ACC: 0.531525929610745
499 | Eval: Epoch: 459 dataset: RESISC45 ACC: 0.6826984126984127
500 | Eval: Epoch: 459 dataset: EuroSAT ACC: 0.8455555555555555
501 | Eval: Epoch: 459 dataset: SVHN ACC: 0.822180393362016
502 | Eval: Epoch: 459 dataset: GTSRB ACC: 0.846714172604909
503 | Eval: Epoch: 459 dataset: MNIST ACC: 0.9261
504 | Eval: Epoch: 459 dataset: DTD ACC: 0.44893617021276594
505 | Eval: Epoch: 459 Avg ACC:0.7102346594461273
506 |
507 | [tensor([1.0000, 0.2146, 0.1399, 0.2744, 0.3260, 0.2913, 0.4018, 0.1991, 0.1694])]
508 | Eval: Epoch: 469 dataset: SUN397 ACC: 0.5782672097350028
509 | Eval: Epoch: 469 dataset: Cars ACC: 0.5318990175351325
510 | Eval: Epoch: 469 dataset: RESISC45 ACC: 0.6825396825396826
511 | Eval: Epoch: 469 dataset: EuroSAT ACC: 0.8496296296296296
512 | Eval: Epoch: 469 dataset: SVHN ACC: 0.8200291948371236
513 | Eval: Epoch: 469 dataset: GTSRB ACC: 0.846159936658749
514 | Eval: Epoch: 469 dataset: MNIST ACC: 0.9262
515 | Eval: Epoch: 469 dataset: DTD ACC: 0.4478723404255319
516 | Eval: Epoch: 469 Avg ACC:0.7103246264201064
517 |
518 | [tensor([1.0000, 0.2164, 0.1411, 0.2774, 0.3262, 0.2905, 0.4019, 0.1989, 0.1706])]
519 | Eval: Epoch: 479 dataset: SUN397 ACC: 0.5784683461557801
520 | Eval: Epoch: 479 dataset: Cars ACC: 0.5317746548936699
521 | Eval: Epoch: 479 dataset: RESISC45 ACC: 0.6828571428571428
522 | Eval: Epoch: 479 dataset: EuroSAT ACC: 0.8496296296296296
523 | Eval: Epoch: 479 dataset: SVHN ACC: 0.818799938537185
524 | Eval: Epoch: 479 dataset: GTSRB ACC: 0.8458432304038005
525 | Eval: Epoch: 479 dataset: MNIST ACC: 0.9259
526 | Eval: Epoch: 479 dataset: DTD ACC: 0.44840425531914896
527 | Eval: Epoch: 479 Avg ACC:0.7102096497245447
528 |
529 | [tensor([1.0000, 0.2194, 0.1419, 0.2805, 0.3284, 0.2863, 0.4000, 0.1980, 0.1706])]
530 | Eval: Epoch: 489 dataset: SUN397 ACC: 0.5795243123648615
531 | Eval: Epoch: 489 dataset: Cars ACC: 0.5326451933839075
532 | Eval: Epoch: 489 dataset: RESISC45 ACC: 0.6861904761904762
533 | Eval: Epoch: 489 dataset: EuroSAT ACC: 0.8555555555555555
534 | Eval: Epoch: 489 dataset: SVHN ACC: 0.8139213275968039
535 | Eval: Epoch: 489 dataset: GTSRB ACC: 0.8444972288202692
536 | Eval: Epoch: 489 dataset: MNIST ACC: 0.925
537 | Eval: Epoch: 489 dataset: DTD ACC: 0.449468085106383
538 | Eval: Epoch: 489 Avg ACC:0.710850272377282
539 |
540 | [tensor([1.0000, 0.2202, 0.1413, 0.2826, 0.3284, 0.2841, 0.4003, 0.1978, 0.1692])]
541 | Eval: Epoch: 499 dataset: SUN397 ACC: 0.5802785739427767
542 | Eval: Epoch: 499 dataset: Cars ACC: 0.53227210545952
543 | Eval: Epoch: 499 dataset: RESISC45 ACC: 0.6880952380952381
544 | Eval: Epoch: 499 dataset: EuroSAT ACC: 0.857037037037037
545 | Eval: Epoch: 499 dataset: SVHN ACC: 0.8118853718500307
546 | Eval: Epoch: 499 dataset: GTSRB ACC: 0.8448931116389549
547 | Eval: Epoch: 499 dataset: MNIST ACC: 0.9244
548 | Eval: Epoch: 499 dataset: DTD ACC: 0.44840425531914896
549 | Eval: Epoch: 499 Avg ACC:0.7109082116678384
550 |
--------------------------------------------------------------------------------
/logs/ViT-B-32/log_20230921_205130_Task_wise_AdaMergingPP.txt:
--------------------------------------------------------------------------------
1 | Eval: init: dataset: SUN397 ACC: 0.6506260371096696
2 | Eval: init: dataset: Cars ACC: 0.6443228454172366
3 | Eval: init: dataset: RESISC45 ACC: 0.7485714285714286
4 | Eval: init: dataset: EuroSAT ACC: 0.7733333333333333
5 | Eval: init: dataset: SVHN ACC: 0.8126536570374924
6 | Eval: init: dataset: GTSRB ACC: 0.6938242280285035
7 | Eval: init: dataset: MNIST ACC: 0.9652
8 | Eval: init: dataset: DTD ACC: 0.5452127659574468
9 | Eval: init: Avg ACC:0.7292180369318889
10 |
11 | [tensor([1.0000, 0.3099, 0.3098, 0.3097, 0.3097, 0.3097, 0.3098, 0.3098, 0.3100])]
12 | Eval: Epoch: 9 dataset: SUN397 ACC: 0.6471061497460653
13 | Eval: Epoch: 9 dataset: Cars ACC: 0.6422086805123741
14 | Eval: Epoch: 9 dataset: RESISC45 ACC: 0.7473015873015874
15 | Eval: Epoch: 9 dataset: EuroSAT ACC: 0.7762962962962963
16 | Eval: Epoch: 9 dataset: SVHN ACC: 0.8194529809465273
17 | Eval: Epoch: 9 dataset: GTSRB ACC: 0.6992082343626287
18 | Eval: Epoch: 9 dataset: MNIST ACC: 0.9673
19 | Eval: Epoch: 9 dataset: DTD ACC: 0.5425531914893617
20 | Eval: Epoch: 9 Avg ACC:0.7301783900818551
21 |
22 | [tensor([1.0000, 0.3193, 0.3193, 0.3192, 0.3195, 0.3188, 0.3194, 0.3187, 0.3195])]
23 | Eval: Epoch: 19 dataset: SUN397 ACC: 0.6437873988032383
24 | Eval: Epoch: 19 dataset: Cars ACC: 0.6387265265514239
25 | Eval: Epoch: 19 dataset: RESISC45 ACC: 0.7458730158730159
26 | Eval: Epoch: 19 dataset: EuroSAT ACC: 0.7785185185185185
27 | Eval: Epoch: 19 dataset: SVHN ACC: 0.8256376767055931
28 | Eval: Epoch: 19 dataset: GTSRB ACC: 0.7037212984956452
29 | Eval: Epoch: 19 dataset: MNIST ACC: 0.9697
30 | Eval: Epoch: 19 dataset: DTD ACC: 0.5452127659574468
31 | Eval: Epoch: 19 Avg ACC:0.7313971501131101
32 |
33 | [tensor([1.0000, 0.3276, 0.3273, 0.3280, 0.3287, 0.3275, 0.3284, 0.3267, 0.3285])]
34 | Eval: Epoch: 29 dataset: SUN397 ACC: 0.6409714889123548
35 | Eval: Epoch: 29 dataset: Cars ACC: 0.6376072627782614
36 | Eval: Epoch: 29 dataset: RESISC45 ACC: 0.746031746031746
37 | Eval: Epoch: 29 dataset: EuroSAT ACC: 0.7814814814814814
38 | Eval: Epoch: 29 dataset: SVHN ACC: 0.8315150583896742
39 | Eval: Epoch: 29 dataset: GTSRB ACC: 0.7098178939034046
40 | Eval: Epoch: 29 dataset: MNIST ACC: 0.9713
41 | Eval: Epoch: 29 dataset: DTD ACC: 0.5425531914893617
42 | Eval: Epoch: 29 Avg ACC:0.7326597653732856
43 |
44 | [tensor([1.0000, 0.3348, 0.3332, 0.3360, 0.3376, 0.3350, 0.3368, 0.3331, 0.3357])]
45 | Eval: Epoch: 39 dataset: SUN397 ACC: 0.6381555790214714
46 | Eval: Epoch: 39 dataset: Cars ACC: 0.6336276582514613
47 | Eval: Epoch: 39 dataset: RESISC45 ACC: 0.7442857142857143
48 | Eval: Epoch: 39 dataset: EuroSAT ACC: 0.782962962962963
49 | Eval: Epoch: 39 dataset: SVHN ACC: 0.8355485556238476
50 | Eval: Epoch: 39 dataset: GTSRB ACC: 0.7125890736342043
51 | Eval: Epoch: 39 dataset: MNIST ACC: 0.9724
52 | Eval: Epoch: 39 dataset: DTD ACC: 0.5430851063829787
53 | Eval: Epoch: 39 Avg ACC:0.73283183127033
54 |
55 | [tensor([1.0000, 0.3404, 0.3377, 0.3434, 0.3449, 0.3424, 0.3452, 0.3396, 0.3416])]
56 | Eval: Epoch: 49 dataset: SUN397 ACC: 0.6349876803942274
57 | Eval: Epoch: 49 dataset: Cars ACC: 0.6310160427807486
58 | Eval: Epoch: 49 dataset: RESISC45 ACC: 0.7417460317460317
59 | Eval: Epoch: 49 dataset: EuroSAT ACC: 0.782962962962963
60 | Eval: Epoch: 49 dataset: SVHN ACC: 0.8390058389674248
61 | Eval: Epoch: 49 dataset: GTSRB ACC: 0.7167062549485352
62 | Eval: Epoch: 49 dataset: MNIST ACC: 0.9734
63 | Eval: Epoch: 49 dataset: DTD ACC: 0.5425531914893617
64 | Eval: Epoch: 49 Avg ACC:0.7327972504111615
65 |
66 | [tensor([1.0000, 0.3443, 0.3403, 0.3495, 0.3520, 0.3486, 0.3531, 0.3444, 0.3460])]
67 | Eval: Epoch: 59 dataset: SUN397 ACC: 0.6309146678734852
68 | Eval: Epoch: 59 dataset: Cars ACC: 0.6280313393856486
69 | Eval: Epoch: 59 dataset: RESISC45 ACC: 0.7415873015873016
70 | Eval: Epoch: 59 dataset: EuroSAT ACC: 0.7837037037037037
71 | Eval: Epoch: 59 dataset: SVHN ACC: 0.8426167793484942
72 | Eval: Epoch: 59 dataset: GTSRB ACC: 0.7209817893903404
73 | Eval: Epoch: 59 dataset: MNIST ACC: 0.974
74 | Eval: Epoch: 59 dataset: DTD ACC: 0.5414893617021277
75 | Eval: Epoch: 59 Avg ACC:0.7329156178738878
76 |
77 | [tensor([1.0000, 0.3467, 0.3406, 0.3542, 0.3594, 0.3536, 0.3589, 0.3465, 0.3478])]
78 | Eval: Epoch: 69 dataset: SUN397 ACC: 0.6289535877709056
79 | Eval: Epoch: 69 dataset: Cars ACC: 0.6250466359905484
80 | Eval: Epoch: 69 dataset: RESISC45 ACC: 0.7409523809523809
81 | Eval: Epoch: 69 dataset: EuroSAT ACC: 0.7855555555555556
82 | Eval: Epoch: 69 dataset: SVHN ACC: 0.8449216349108789
83 | Eval: Epoch: 69 dataset: GTSRB ACC: 0.7228028503562945
84 | Eval: Epoch: 69 dataset: MNIST ACC: 0.9746
85 | Eval: Epoch: 69 dataset: DTD ACC: 0.5420212765957447
86 | Eval: Epoch: 69 Avg ACC:0.7331067402665385
87 |
88 | [tensor([1.0000, 0.3476, 0.3384, 0.3582, 0.3673, 0.3566, 0.3633, 0.3462, 0.3474])]
89 | Eval: Epoch: 79 dataset: SUN397 ACC: 0.6278976215618243
90 | Eval: Epoch: 79 dataset: Cars ACC: 0.6228081084442234
91 | Eval: Epoch: 79 dataset: RESISC45 ACC: 0.7414285714285714
92 | Eval: Epoch: 79 dataset: EuroSAT ACC: 0.7881481481481482
93 | Eval: Epoch: 79 dataset: SVHN ACC: 0.8458051628764598
94 | Eval: Epoch: 79 dataset: GTSRB ACC: 0.7253365003958828
95 | Eval: Epoch: 79 dataset: MNIST ACC: 0.9746
96 | Eval: Epoch: 79 dataset: DTD ACC: 0.5388297872340425
97 | Eval: Epoch: 79 Avg ACC:0.733106737511144
98 |
99 | [tensor([1.0000, 0.3470, 0.3339, 0.3620, 0.3748, 0.3599, 0.3683, 0.3457, 0.3450])]
100 | Eval: Epoch: 89 dataset: SUN397 ACC: 0.6264896666163826
101 | Eval: Epoch: 89 dataset: Cars ACC: 0.6218132073125233
102 | Eval: Epoch: 89 dataset: RESISC45 ACC: 0.7412698412698413
103 | Eval: Epoch: 89 dataset: EuroSAT ACC: 0.7911111111111111
104 | Eval: Epoch: 89 dataset: SVHN ACC: 0.8471496619545175
105 | Eval: Epoch: 89 dataset: GTSRB ACC: 0.7281076801266825
106 | Eval: Epoch: 89 dataset: MNIST ACC: 0.9746
107 | Eval: Epoch: 89 dataset: DTD ACC: 0.5382978723404256
108 | Eval: Epoch: 89 Avg ACC:0.7336048800914354
109 |
110 | [tensor([1.0000, 0.3472, 0.3296, 0.3660, 0.3814, 0.3640, 0.3740, 0.3456, 0.3443])]
111 | Eval: Epoch: 99 dataset: SUN397 ACC: 0.6250817116709408
112 | Eval: Epoch: 99 dataset: Cars ACC: 0.6195746797661983
113 | Eval: Epoch: 99 dataset: RESISC45 ACC: 0.7412698412698413
114 | Eval: Epoch: 99 dataset: EuroSAT ACC: 0.7937037037037037
115 | Eval: Epoch: 99 dataset: SVHN ACC: 0.8491472034419176
116 | Eval: Epoch: 99 dataset: GTSRB ACC: 0.7309580364212194
117 | Eval: Epoch: 99 dataset: MNIST ACC: 0.9748
118 | Eval: Epoch: 99 dataset: DTD ACC: 0.5361702127659574
119 | Eval: Epoch: 99 Avg ACC:0.7338381736299723
120 |
121 | [tensor([1.0000, 0.3466, 0.3237, 0.3691, 0.3876, 0.3676, 0.3796, 0.3439, 0.3424])]
122 | Eval: Epoch: 109 dataset: SUN397 ACC: 0.6240760295670539
123 | Eval: Epoch: 109 dataset: Cars ACC: 0.6174605148613357
124 | Eval: Epoch: 109 dataset: RESISC45 ACC: 0.7409523809523809
125 | Eval: Epoch: 109 dataset: EuroSAT ACC: 0.7959259259259259
126 | Eval: Epoch: 109 dataset: SVHN ACC: 0.8504532882606023
127 | Eval: Epoch: 109 dataset: GTSRB ACC: 0.7347585114806018
128 | Eval: Epoch: 109 dataset: MNIST ACC: 0.9746
129 | Eval: Epoch: 109 dataset: DTD ACC: 0.5351063829787234
130 | Eval: Epoch: 109 Avg ACC:0.7341666292533279
131 |
132 | [tensor([1.0000, 0.3455, 0.3182, 0.3716, 0.3935, 0.3712, 0.3852, 0.3420, 0.3414])]
133 | Eval: Epoch: 119 dataset: SUN397 ACC: 0.6230200633579726
134 | Eval: Epoch: 119 dataset: Cars ACC: 0.614848899390623
135 | Eval: Epoch: 119 dataset: RESISC45 ACC: 0.7409523809523809
136 | Eval: Epoch: 119 dataset: EuroSAT ACC: 0.797037037037037
137 | Eval: Epoch: 119 dataset: SVHN ACC: 0.8521051014136447
138 | Eval: Epoch: 119 dataset: GTSRB ACC: 0.7380839271575613
139 | Eval: Epoch: 119 dataset: MNIST ACC: 0.9743
140 | Eval: Epoch: 119 dataset: DTD ACC: 0.5351063829787234
141 | Eval: Epoch: 119 Avg ACC:0.7344317240359929
142 |
143 | [tensor([1.0000, 0.3425, 0.3124, 0.3743, 0.3994, 0.3754, 0.3905, 0.3398, 0.3399])]
144 | Eval: Epoch: 129 dataset: SUN397 ACC: 0.622366369990446
145 | Eval: Epoch: 129 dataset: Cars ACC: 0.6131078224101479
146 | Eval: Epoch: 129 dataset: RESISC45 ACC: 0.74
147 | Eval: Epoch: 129 dataset: EuroSAT ACC: 0.8003703703703704
148 | Eval: Epoch: 129 dataset: SVHN ACC: 0.8538721573448064
149 | Eval: Epoch: 129 dataset: GTSRB ACC: 0.7412509897070467
150 | Eval: Epoch: 129 dataset: MNIST ACC: 0.974
151 | Eval: Epoch: 129 dataset: DTD ACC: 0.5319148936170213
152 | Eval: Epoch: 129 Avg ACC:0.7346103254299798
153 |
154 | [tensor([1.0000, 0.3403, 0.3060, 0.3772, 0.4051, 0.3792, 0.3964, 0.3375, 0.3378])]
155 | Eval: Epoch: 139 dataset: SUN397 ACC: 0.6214612560969478
156 | Eval: Epoch: 139 dataset: Cars ACC: 0.6109936575052854
157 | Eval: Epoch: 139 dataset: RESISC45 ACC: 0.7395238095238095
158 | Eval: Epoch: 139 dataset: EuroSAT ACC: 0.8022222222222222
159 | Eval: Epoch: 139 dataset: SVHN ACC: 0.8551014136447449
160 | Eval: Epoch: 139 dataset: GTSRB ACC: 0.7441805225653206
161 | Eval: Epoch: 139 dataset: MNIST ACC: 0.9738
162 | Eval: Epoch: 139 dataset: DTD ACC: 0.5281914893617021
163 | Eval: Epoch: 139 Avg ACC:0.734434296365004
164 |
165 | [tensor([1.0000, 0.3377, 0.2994, 0.3791, 0.4113, 0.3832, 0.4020, 0.3357, 0.3349])]
166 | Eval: Epoch: 149 dataset: SUN397 ACC: 0.6211092673605874
167 | Eval: Epoch: 149 dataset: Cars ACC: 0.6088794926004228
168 | Eval: Epoch: 149 dataset: RESISC45 ACC: 0.7401587301587301
169 | Eval: Epoch: 149 dataset: EuroSAT ACC: 0.8037037037037037
170 | Eval: Epoch: 149 dataset: SVHN ACC: 0.8564074984634297
171 | Eval: Epoch: 149 dataset: GTSRB ACC: 0.7464766429136975
172 | Eval: Epoch: 149 dataset: MNIST ACC: 0.9736
173 | Eval: Epoch: 149 dataset: DTD ACC: 0.527127659574468
174 | Eval: Epoch: 149 Avg ACC:0.73468287434688
175 |
176 | [tensor([1.0000, 0.3358, 0.2940, 0.3806, 0.4176, 0.3850, 0.4085, 0.3323, 0.3329])]
177 | Eval: Epoch: 159 dataset: SUN397 ACC: 0.6205561422034495
178 | Eval: Epoch: 159 dataset: Cars ACC: 0.6071384156199477
179 | Eval: Epoch: 159 dataset: RESISC45 ACC: 0.7403174603174603
180 | Eval: Epoch: 159 dataset: EuroSAT ACC: 0.8062962962962963
181 | Eval: Epoch: 159 dataset: SVHN ACC: 0.8567148125384143
182 | Eval: Epoch: 159 dataset: GTSRB ACC: 0.7494061757719715
183 | Eval: Epoch: 159 dataset: MNIST ACC: 0.9729
184 | Eval: Epoch: 159 dataset: DTD ACC: 0.5260638297872341
185 | Eval: Epoch: 159 Avg ACC:0.7349241415668468
186 |
187 | [tensor([1.0000, 0.3339, 0.2893, 0.3821, 0.4232, 0.3858, 0.4146, 0.3282, 0.3306])]
188 | Eval: Epoch: 169 dataset: SUN397 ACC: 0.6201035852567004
189 | Eval: Epoch: 169 dataset: Cars ACC: 0.6055217012809352
190 | Eval: Epoch: 169 dataset: RESISC45 ACC: 0.7404761904761905
191 | Eval: Epoch: 169 dataset: EuroSAT ACC: 0.8107407407407408
192 | Eval: Epoch: 169 dataset: SVHN ACC: 0.8567916410571604
193 | Eval: Epoch: 169 dataset: GTSRB ACC: 0.7523357086302455
194 | Eval: Epoch: 169 dataset: MNIST ACC: 0.9727
195 | Eval: Epoch: 169 dataset: DTD ACC: 0.5260638297872341
196 | Eval: Epoch: 169 Avg ACC:0.7355916746536508
197 |
198 | [tensor([1.0000, 0.3329, 0.2844, 0.3846, 0.4307, 0.3880, 0.4196, 0.3265, 0.3289])]
199 | Eval: Epoch: 179 dataset: SUN397 ACC: 0.6191481872580078
200 | Eval: Epoch: 179 dataset: Cars ACC: 0.6034075363760726
201 | Eval: Epoch: 179 dataset: RESISC45 ACC: 0.7395238095238095
202 | Eval: Epoch: 179 dataset: EuroSAT ACC: 0.8144444444444444
203 | Eval: Epoch: 179 dataset: SVHN ACC: 0.8572141979102643
204 | Eval: Epoch: 179 dataset: GTSRB ACC: 0.7551860649247822
205 | Eval: Epoch: 179 dataset: MNIST ACC: 0.9722
206 | Eval: Epoch: 179 dataset: DTD ACC: 0.525
207 | Eval: Epoch: 179 Avg ACC:0.7357655300546726
208 |
209 | [tensor([1.0000, 0.3320, 0.2790, 0.3867, 0.4364, 0.3894, 0.4233, 0.3229, 0.3281])]
210 | Eval: Epoch: 189 dataset: SUN397 ACC: 0.6185950621008699
211 | Eval: Epoch: 189 dataset: Cars ACC: 0.6029100858102227
212 | Eval: Epoch: 189 dataset: RESISC45 ACC: 0.74
213 | Eval: Epoch: 189 dataset: EuroSAT ACC: 0.8166666666666667
214 | Eval: Epoch: 189 dataset: SVHN ACC: 0.8573678549477566
215 | Eval: Epoch: 189 dataset: GTSRB ACC: 0.7566112430720506
216 | Eval: Epoch: 189 dataset: MNIST ACC: 0.9717
217 | Eval: Epoch: 189 dataset: DTD ACC: 0.523936170212766
218 | Eval: Epoch: 189 Avg ACC:0.7359733853512915
219 |
220 | [tensor([1.0000, 0.3317, 0.2732, 0.3899, 0.4427, 0.3901, 0.4268, 0.3188, 0.3264])]
221 | Eval: Epoch: 199 dataset: SUN397 ACC: 0.6182430733645095
222 | Eval: Epoch: 199 dataset: Cars ACC: 0.60179082203706
223 | Eval: Epoch: 199 dataset: RESISC45 ACC: 0.7412698412698413
224 | Eval: Epoch: 199 dataset: EuroSAT ACC: 0.8218518518518518
225 | Eval: Epoch: 199 dataset: SVHN ACC: 0.8574446834665027
226 | Eval: Epoch: 199 dataset: GTSRB ACC: 0.7584323040380048
227 | Eval: Epoch: 199 dataset: MNIST ACC: 0.9713
228 | Eval: Epoch: 199 dataset: DTD ACC: 0.524468085106383
229 | Eval: Epoch: 199 Avg ACC:0.7368500826417692
230 |
231 | [tensor([1.0000, 0.3308, 0.2676, 0.3921, 0.4492, 0.3902, 0.4309, 0.3136, 0.3253])]
232 | Eval: Epoch: 209 dataset: SUN397 ACC: 0.6182933574697038
233 | Eval: Epoch: 209 dataset: Cars ACC: 0.6006715582638975
234 | Eval: Epoch: 209 dataset: RESISC45 ACC: 0.7420634920634921
235 | Eval: Epoch: 209 dataset: EuroSAT ACC: 0.8248148148148148
236 | Eval: Epoch: 209 dataset: SVHN ACC: 0.8571373693915181
237 | Eval: Epoch: 209 dataset: GTSRB ACC: 0.7599366587490103
238 | Eval: Epoch: 209 dataset: MNIST ACC: 0.9707
239 | Eval: Epoch: 209 dataset: DTD ACC: 0.5234042553191489
240 | Eval: Epoch: 209 Avg ACC:0.7371276882589483
241 |
242 | [tensor([1.0000, 0.3294, 0.2619, 0.3941, 0.4549, 0.3916, 0.4361, 0.3092, 0.3243])]
243 | Eval: Epoch: 219 dataset: SUN397 ACC: 0.6176899482073717
244 | Eval: Epoch: 219 dataset: Cars ACC: 0.59880611864196
245 | Eval: Epoch: 219 dataset: RESISC45 ACC: 0.7423809523809524
246 | Eval: Epoch: 219 dataset: EuroSAT ACC: 0.827037037037037
247 | Eval: Epoch: 219 dataset: SVHN ACC: 0.8574830977258758
248 | Eval: Epoch: 219 dataset: GTSRB ACC: 0.7624703087885986
249 | Eval: Epoch: 219 dataset: MNIST ACC: 0.9698
250 | Eval: Epoch: 219 dataset: DTD ACC: 0.523936170212766
251 | Eval: Epoch: 219 Avg ACC:0.7374504541243202
252 |
253 | [tensor([1.0000, 0.3291, 0.2572, 0.3962, 0.4600, 0.3936, 0.4420, 0.3061, 0.3259])]
254 | Eval: Epoch: 229 dataset: SUN397 ACC: 0.6171368230502339
255 | Eval: Epoch: 229 dataset: Cars ACC: 0.5971894043029474
256 | Eval: Epoch: 229 dataset: RESISC45 ACC: 0.7412698412698413
257 | Eval: Epoch: 229 dataset: EuroSAT ACC: 0.8274074074074074
258 | Eval: Epoch: 229 dataset: SVHN ACC: 0.8582513829133375
259 | Eval: Epoch: 229 dataset: GTSRB ACC: 0.7649247822644497
260 | Eval: Epoch: 229 dataset: MNIST ACC: 0.969
261 | Eval: Epoch: 229 dataset: DTD ACC: 0.5228723404255319
262 | Eval: Epoch: 229 Avg ACC:0.7372564977042186
263 |
264 | [tensor([1.0000, 0.3283, 0.2522, 0.3975, 0.4634, 0.3958, 0.4489, 0.3045, 0.3253])]
265 | Eval: Epoch: 239 dataset: SUN397 ACC: 0.6169356866294564
266 | Eval: Epoch: 239 dataset: Cars ACC: 0.5950752393980848
267 | Eval: Epoch: 239 dataset: RESISC45 ACC: 0.7404761904761905
268 | Eval: Epoch: 239 dataset: EuroSAT ACC: 0.8285185185185185
269 | Eval: Epoch: 239 dataset: SVHN ACC: 0.859519053472649
270 | Eval: Epoch: 239 dataset: GTSRB ACC: 0.7683293745051465
271 | Eval: Epoch: 239 dataset: MNIST ACC: 0.9686
272 | Eval: Epoch: 239 dataset: DTD ACC: 0.5218085106382979
273 | Eval: Epoch: 239 Avg ACC:0.737407821704793
274 |
275 | [tensor([1.0000, 0.3274, 0.2476, 0.3985, 0.4663, 0.3978, 0.4552, 0.3022, 0.3211])]
276 | Eval: Epoch: 249 dataset: SUN397 ACC: 0.6163825614723186
277 | Eval: Epoch: 249 dataset: Cars ACC: 0.5945777888322348
278 | Eval: Epoch: 249 dataset: RESISC45 ACC: 0.7403174603174603
279 | Eval: Epoch: 249 dataset: EuroSAT ACC: 0.8288888888888889
280 | Eval: Epoch: 249 dataset: SVHN ACC: 0.8600952673632453
281 | Eval: Epoch: 249 dataset: GTSRB ACC: 0.7714964370546318
282 | Eval: Epoch: 249 dataset: MNIST ACC: 0.9682
283 | Eval: Epoch: 249 dataset: DTD ACC: 0.5202127659574468
284 | Eval: Epoch: 249 Avg ACC:0.7375213962357782
285 |
286 | [tensor([1.0000, 0.3266, 0.2447, 0.3996, 0.4698, 0.4008, 0.4610, 0.3012, 0.3177])]
287 | Eval: Epoch: 259 dataset: SUN397 ACC: 0.6157791522099865
288 | Eval: Epoch: 259 dataset: Cars ACC: 0.5938316129834598
289 | Eval: Epoch: 259 dataset: RESISC45 ACC: 0.7392063492063492
290 | Eval: Epoch: 259 dataset: EuroSAT ACC: 0.8288888888888889
291 | Eval: Epoch: 259 dataset: SVHN ACC: 0.8614013521819299
292 | Eval: Epoch: 259 dataset: GTSRB ACC: 0.7749010292953286
293 | Eval: Epoch: 259 dataset: MNIST ACC: 0.9679
294 | Eval: Epoch: 259 dataset: DTD ACC: 0.5191489361702127
295 | Eval: Epoch: 259 Avg ACC:0.7376321651170195
296 |
297 | [tensor([1.0000, 0.3257, 0.2417, 0.4008, 0.4722, 0.4042, 0.4676, 0.3001, 0.3163])]
298 | Eval: Epoch: 269 dataset: SUN397 ACC: 0.6146729018957108
299 | Eval: Epoch: 269 dataset: Cars ACC: 0.5920905360029847
300 | Eval: Epoch: 269 dataset: RESISC45 ACC: 0.7384126984126984
301 | Eval: Epoch: 269 dataset: EuroSAT ACC: 0.8292592592592593
302 | Eval: Epoch: 269 dataset: SVHN ACC: 0.8626690227412416
303 | Eval: Epoch: 269 dataset: GTSRB ACC: 0.7787806809184481
304 | Eval: Epoch: 269 dataset: MNIST ACC: 0.9679
305 | Eval: Epoch: 269 dataset: DTD ACC: 0.5175531914893617
306 | Eval: Epoch: 269 Avg ACC:0.737667286339963
307 |
308 | [tensor([1.0000, 0.3249, 0.2389, 0.4023, 0.4755, 0.4077, 0.4746, 0.2987, 0.3152])]
309 | Eval: Epoch: 279 dataset: SUN397 ACC: 0.6138180721074068
310 | Eval: Epoch: 279 dataset: Cars ACC: 0.5894789205322721
311 | Eval: Epoch: 279 dataset: RESISC45 ACC: 0.7376190476190476
312 | Eval: Epoch: 279 dataset: EuroSAT ACC: 0.8288888888888889
313 | Eval: Epoch: 279 dataset: SVHN ACC: 0.8640519360786724
314 | Eval: Epoch: 279 dataset: GTSRB ACC: 0.7829770387965163
315 | Eval: Epoch: 279 dataset: MNIST ACC: 0.9678
316 | Eval: Epoch: 279 dataset: DTD ACC: 0.5164893617021277
317 | Eval: Epoch: 279 Avg ACC:0.7376404082156165
318 |
319 | [tensor([1.0000, 0.3232, 0.2341, 0.4022, 0.4776, 0.4107, 0.4808, 0.2955, 0.3127])]
320 | Eval: Epoch: 289 dataset: SUN397 ACC: 0.6135163674762407
321 | Eval: Epoch: 289 dataset: Cars ACC: 0.5883596567591096
322 | Eval: Epoch: 289 dataset: RESISC45 ACC: 0.7368253968253968
323 | Eval: Epoch: 289 dataset: EuroSAT ACC: 0.8274074074074074
324 | Eval: Epoch: 289 dataset: SVHN ACC: 0.865319606637984
325 | Eval: Epoch: 289 dataset: GTSRB ACC: 0.7855106888361045
326 | Eval: Epoch: 289 dataset: MNIST ACC: 0.9671
327 | Eval: Epoch: 289 dataset: DTD ACC: 0.5159574468085106
328 | Eval: Epoch: 289 Avg ACC:0.7374995713438441
329 |
330 | [tensor([1.0000, 0.3215, 0.2290, 0.4009, 0.4804, 0.4135, 0.4863, 0.2920, 0.3091])]
331 | Eval: Epoch: 299 dataset: SUN397 ACC: 0.6132146628450746
332 | Eval: Epoch: 299 dataset: Cars ACC: 0.587240392985947
333 | Eval: Epoch: 299 dataset: RESISC45 ACC: 0.7363492063492063
334 | Eval: Epoch: 299 dataset: EuroSAT ACC: 0.8281481481481482
335 | Eval: Epoch: 299 dataset: SVHN ACC: 0.8663952059004303
336 | Eval: Epoch: 299 dataset: GTSRB ACC: 0.7885193982581156
337 | Eval: Epoch: 299 dataset: MNIST ACC: 0.9668
338 | Eval: Epoch: 299 dataset: DTD ACC: 0.5154255319148936
339 | Eval: Epoch: 299 Avg ACC:0.737761568300227
340 |
341 | [tensor([1.0000, 0.3209, 0.2266, 0.3999, 0.4833, 0.4156, 0.4916, 0.2906, 0.3074])]
342 | Eval: Epoch: 309 dataset: SUN397 ACC: 0.6129632423191029
343 | Eval: Epoch: 309 dataset: Cars ACC: 0.586494217137172
344 | Eval: Epoch: 309 dataset: RESISC45 ACC: 0.734920634920635
345 | Eval: Epoch: 309 dataset: EuroSAT ACC: 0.8277777777777777
346 | Eval: Epoch: 309 dataset: SVHN ACC: 0.8671634910878918
347 | Eval: Epoch: 309 dataset: GTSRB ACC: 0.7908155186064925
348 | Eval: Epoch: 309 dataset: MNIST ACC: 0.9667
349 | Eval: Epoch: 309 dataset: DTD ACC: 0.5148936170212766
350 | Eval: Epoch: 309 Avg ACC:0.7377160623587936
351 |
352 | [tensor([1.0000, 0.3216, 0.2241, 0.4017, 0.4862, 0.4166, 0.4966, 0.2893, 0.3083])]
353 | Eval: Epoch: 319 dataset: SUN397 ACC: 0.6126615376879369
354 | Eval: Epoch: 319 dataset: Cars ACC: 0.5853749533640095
355 | Eval: Epoch: 319 dataset: RESISC45 ACC: 0.7347619047619047
356 | Eval: Epoch: 319 dataset: EuroSAT ACC: 0.8281481481481482
357 | Eval: Epoch: 319 dataset: SVHN ACC: 0.8670866625691457
358 | Eval: Epoch: 319 dataset: GTSRB ACC: 0.7926365795724466
359 | Eval: Epoch: 319 dataset: MNIST ACC: 0.9663
360 | Eval: Epoch: 319 dataset: DTD ACC: 0.5148936170212766
361 | Eval: Epoch: 319 Avg ACC:0.7377329253906085
362 |
363 | [tensor([1.0000, 0.3216, 0.2195, 0.4032, 0.4903, 0.4195, 0.5017, 0.2877, 0.3084])]
364 | Eval: Epoch: 329 dataset: SUN397 ACC: 0.6120581284256047
365 | Eval: Epoch: 329 dataset: Cars ACC: 0.5843800522323094
366 | Eval: Epoch: 329 dataset: RESISC45 ACC: 0.7347619047619047
367 | Eval: Epoch: 329 dataset: EuroSAT ACC: 0.8296296296296296
368 | Eval: Epoch: 329 dataset: SVHN ACC: 0.8681238475722188
369 | Eval: Epoch: 329 dataset: GTSRB ACC: 0.7945368171021377
370 | Eval: Epoch: 329 dataset: MNIST ACC: 0.9661
371 | Eval: Epoch: 329 dataset: DTD ACC: 0.5138297872340426
372 | Eval: Epoch: 329 Avg ACC:0.7379275208697309
373 |
374 | [tensor([1.0000, 0.3224, 0.2152, 0.4049, 0.4936, 0.4227, 0.5080, 0.2860, 0.3086])]
375 | Eval: Epoch: 339 dataset: SUN397 ACC: 0.6115552873736612
376 | Eval: Epoch: 339 dataset: Cars ACC: 0.5826389752518344
377 | Eval: Epoch: 339 dataset: RESISC45 ACC: 0.7336507936507937
378 | Eval: Epoch: 339 dataset: EuroSAT ACC: 0.8296296296296296
379 | Eval: Epoch: 339 dataset: SVHN ACC: 0.8690457897971727
380 | Eval: Epoch: 339 dataset: GTSRB ACC: 0.7966745843230404
381 | Eval: Epoch: 339 dataset: MNIST ACC: 0.9661
382 | Eval: Epoch: 339 dataset: DTD ACC: 0.5122340425531915
383 | Eval: Epoch: 339 Avg ACC:0.7376911378224155
384 |
385 | [tensor([1.0000, 0.3216, 0.2118, 0.4060, 0.4965, 0.4258, 0.5153, 0.2827, 0.3070])]
386 | Eval: Epoch: 349 dataset: SUN397 ACC: 0.6108010257957459
387 | Eval: Epoch: 349 dataset: Cars ACC: 0.5812709861957468
388 | Eval: Epoch: 349 dataset: RESISC45 ACC: 0.7322222222222222
389 | Eval: Epoch: 349 dataset: EuroSAT ACC: 0.8296296296296296
390 | Eval: Epoch: 349 dataset: SVHN ACC: 0.8699677320221266
391 | Eval: Epoch: 349 dataset: GTSRB ACC: 0.7996832937450514
392 | Eval: Epoch: 349 dataset: MNIST ACC: 0.9655
393 | Eval: Epoch: 349 dataset: DTD ACC: 0.5117021276595745
394 | Eval: Epoch: 349 Avg ACC:0.7375971271587621
395 |
396 | [tensor([1.0000, 0.3193, 0.2079, 0.4089, 0.4994, 0.4267, 0.5212, 0.2778, 0.3044])]
397 | Eval: Epoch: 359 dataset: SUN397 ACC: 0.6105496052697742
398 | Eval: Epoch: 359 dataset: Cars ACC: 0.5806491729884343
399 | Eval: Epoch: 359 dataset: RESISC45 ACC: 0.7331746031746031
400 | Eval: Epoch: 359 dataset: EuroSAT ACC: 0.8311111111111111
401 | Eval: Epoch: 359 dataset: SVHN ACC: 0.8700061462814997
402 | Eval: Epoch: 359 dataset: GTSRB ACC: 0.8027711797307997
403 | Eval: Epoch: 359 dataset: MNIST ACC: 0.9647
404 | Eval: Epoch: 359 dataset: DTD ACC: 0.5111702127659574
405 | Eval: Epoch: 359 Avg ACC:0.7380165039152725
406 |
407 | [tensor([1.0000, 0.3177, 0.2036, 0.4110, 0.5017, 0.4284, 0.5256, 0.2744, 0.3035])]
408 | Eval: Epoch: 369 dataset: SUN397 ACC: 0.6101976165334138
409 | Eval: Epoch: 369 dataset: Cars ACC: 0.5799029971396592
410 | Eval: Epoch: 369 dataset: RESISC45 ACC: 0.733015873015873
411 | Eval: Epoch: 369 dataset: EuroSAT ACC: 0.8314814814814815
412 | Eval: Epoch: 369 dataset: SVHN ACC: 0.8705055316533498
413 | Eval: Epoch: 369 dataset: GTSRB ACC: 0.8051464766429137
414 | Eval: Epoch: 369 dataset: MNIST ACC: 0.9642
415 | Eval: Epoch: 369 dataset: DTD ACC: 0.5101063829787233
416 | Eval: Epoch: 369 Avg ACC:0.7380695449306767
417 |
418 | [tensor([1.0000, 0.3179, 0.2009, 0.4134, 0.5041, 0.4309, 0.5308, 0.2733, 0.3036])]
419 | Eval: Epoch: 379 dataset: SUN397 ACC: 0.6091919344295268
420 | Eval: Epoch: 379 dataset: Cars ACC: 0.5796542718567342
421 | Eval: Epoch: 379 dataset: RESISC45 ACC: 0.7331746031746031
422 | Eval: Epoch: 379 dataset: EuroSAT ACC: 0.8322222222222222
423 | Eval: Epoch: 379 dataset: SVHN ACC: 0.8710049170251998
424 | Eval: Epoch: 379 dataset: GTSRB ACC: 0.806492478226445
425 | Eval: Epoch: 379 dataset: MNIST ACC: 0.9641
426 | Eval: Epoch: 379 dataset: DTD ACC: 0.5101063829787233
427 | Eval: Epoch: 379 Avg ACC:0.7382433512391817
428 |
429 | [tensor([1.0000, 0.3179, 0.1964, 0.4167, 0.5077, 0.4299, 0.5343, 0.2700, 0.3007])]
430 | Eval: Epoch: 389 dataset: SUN397 ACC: 0.609795343691859
431 | Eval: Epoch: 389 dataset: Cars ACC: 0.5782862828006466
432 | Eval: Epoch: 389 dataset: RESISC45 ACC: 0.7339682539682539
433 | Eval: Epoch: 389 dataset: EuroSAT ACC: 0.8340740740740741
434 | Eval: Epoch: 389 dataset: SVHN ACC: 0.870620774431469
435 | Eval: Epoch: 389 dataset: GTSRB ACC: 0.8079176563737134
436 | Eval: Epoch: 389 dataset: MNIST ACC: 0.9637
437 | Eval: Epoch: 389 dataset: DTD ACC: 0.5085106382978724
438 | Eval: Epoch: 389 Avg ACC:0.738359127954736
439 |
440 | [tensor([1.0000, 0.3180, 0.1922, 0.4197, 0.5098, 0.4293, 0.5374, 0.2653, 0.2978])]
441 | Eval: Epoch: 399 dataset: SUN397 ACC: 0.6103484688489969
442 | Eval: Epoch: 399 dataset: Cars ACC: 0.5777888322347967
443 | Eval: Epoch: 399 dataset: RESISC45 ACC: 0.7347619047619047
444 | Eval: Epoch: 399 dataset: EuroSAT ACC: 0.8348148148148148
445 | Eval: Epoch: 399 dataset: SVHN ACC: 0.870159803318992
446 | Eval: Epoch: 399 dataset: GTSRB ACC: 0.809501187648456
447 | Eval: Epoch: 399 dataset: MNIST ACC: 0.9628
448 | Eval: Epoch: 399 dataset: DTD ACC: 0.5074468085106383
449 | Eval: Epoch: 399 Avg ACC:0.7384527275173249
450 |
451 | [tensor([1.0000, 0.3186, 0.1893, 0.4205, 0.5113, 0.4312, 0.5423, 0.2636, 0.2945])]
452 | Eval: Epoch: 409 dataset: SUN397 ACC: 0.6102479006386081
453 | Eval: Epoch: 409 dataset: Cars ACC: 0.5770426563860216
454 | Eval: Epoch: 409 dataset: RESISC45 ACC: 0.7344444444444445
455 | Eval: Epoch: 409 dataset: EuroSAT ACC: 0.8348148148148148
456 | Eval: Epoch: 409 dataset: SVHN ACC: 0.870620774431469
457 | Eval: Epoch: 409 dataset: GTSRB ACC: 0.8119556611243072
458 | Eval: Epoch: 409 dataset: MNIST ACC: 0.9624
459 | Eval: Epoch: 409 dataset: DTD ACC: 0.5063829787234042
460 | Eval: Epoch: 409 Avg ACC:0.7384886538203836
461 |
462 | [tensor([1.0000, 0.3190, 0.1867, 0.4226, 0.5133, 0.4322, 0.5460, 0.2617, 0.2912])]
463 | Eval: Epoch: 419 dataset: SUN397 ACC: 0.6103484688489969
464 | Eval: Epoch: 419 dataset: Cars ACC: 0.5757990299713966
465 | Eval: Epoch: 419 dataset: RESISC45 ACC: 0.734920634920635
466 | Eval: Epoch: 419 dataset: EuroSAT ACC: 0.8348148148148148
467 | Eval: Epoch: 419 dataset: SVHN ACC: 0.8708512599877074
468 | Eval: Epoch: 419 dataset: GTSRB ACC: 0.8129057798891528
469 | Eval: Epoch: 419 dataset: MNIST ACC: 0.9621
470 | Eval: Epoch: 419 dataset: DTD ACC: 0.5042553191489362
471 | Eval: Epoch: 419 Avg ACC:0.738249413447705
472 |
473 | [tensor([1.0000, 0.3194, 0.1848, 0.4243, 0.5150, 0.4321, 0.5501, 0.2585, 0.2900])]
474 | Eval: Epoch: 429 dataset: SUN397 ACC: 0.6098456277970533
475 | Eval: Epoch: 429 dataset: Cars ACC: 0.5745554035567716
476 | Eval: Epoch: 429 dataset: RESISC45 ACC: 0.7350793650793651
477 | Eval: Epoch: 429 dataset: EuroSAT ACC: 0.8366666666666667
478 | Eval: Epoch: 429 dataset: SVHN ACC: 0.8705823601720959
479 | Eval: Epoch: 429 dataset: GTSRB ACC: 0.8142517814726841
480 | Eval: Epoch: 429 dataset: MNIST ACC: 0.9612
481 | Eval: Epoch: 429 dataset: DTD ACC: 0.5042553191489362
482 | Eval: Epoch: 429 Avg ACC:0.7383045654866965
483 |
484 | [tensor([1.0000, 0.3198, 0.1828, 0.4259, 0.5156, 0.4334, 0.5541, 0.2581, 0.2910])]
485 | Eval: Epoch: 439 dataset: SUN397 ACC: 0.6091919344295268
486 | Eval: Epoch: 439 dataset: Cars ACC: 0.5740579529909215
487 | Eval: Epoch: 439 dataset: RESISC45 ACC: 0.7350793650793651
488 | Eval: Epoch: 439 dataset: EuroSAT ACC: 0.8366666666666667
489 | Eval: Epoch: 439 dataset: SVHN ACC: 0.8709665027658267
490 | Eval: Epoch: 439 dataset: GTSRB ACC: 0.8155977830562153
491 | Eval: Epoch: 439 dataset: MNIST ACC: 0.9611
492 | Eval: Epoch: 439 dataset: DTD ACC: 0.5037234042553191
493 | Eval: Epoch: 439 Avg ACC:0.73829795115548
494 |
495 | [tensor([1.0000, 0.3201, 0.1804, 0.4269, 0.5161, 0.4341, 0.5577, 0.2541, 0.2909])]
496 | Eval: Epoch: 449 dataset: SUN397 ACC: 0.6094936390606929
497 | Eval: Epoch: 449 dataset: Cars ACC: 0.5733117771421465
498 | Eval: Epoch: 449 dataset: RESISC45 ACC: 0.7353968253968254
499 | Eval: Epoch: 449 dataset: EuroSAT ACC: 0.837037037037037
500 | Eval: Epoch: 449 dataset: SVHN ACC: 0.8711969883220652
501 | Eval: Epoch: 449 dataset: GTSRB ACC: 0.8166270783847981
502 | Eval: Epoch: 449 dataset: MNIST ACC: 0.96
503 | Eval: Epoch: 449 dataset: DTD ACC: 0.5037234042553191
504 | Eval: Epoch: 449 Avg ACC:0.7383483436998605
505 |
506 | [tensor([1.0000, 0.3190, 0.1755, 0.4256, 0.5159, 0.4341, 0.5609, 0.2507, 0.2898])]
507 | Eval: Epoch: 459 dataset: SUN397 ACC: 0.6099964801126364
508 | Eval: Epoch: 459 dataset: Cars ACC: 0.5723168760104465
509 | Eval: Epoch: 459 dataset: RESISC45 ACC: 0.7353968253968254
510 | Eval: Epoch: 459 dataset: EuroSAT ACC: 0.8362962962962963
511 | Eval: Epoch: 459 dataset: SVHN ACC: 0.8712354025814383
512 | Eval: Epoch: 459 dataset: GTSRB ACC: 0.8187648456057007
513 | Eval: Epoch: 459 dataset: MNIST ACC: 0.959
514 | Eval: Epoch: 459 dataset: DTD ACC: 0.5037234042553191
515 | Eval: Epoch: 459 Avg ACC:0.7383412662823328
516 |
517 | [tensor([1.0000, 0.3193, 0.1727, 0.4249, 0.5163, 0.4341, 0.5663, 0.2486, 0.2893])]
518 | Eval: Epoch: 469 dataset: SUN397 ACC: 0.609946196007442
519 | Eval: Epoch: 469 dataset: Cars ACC: 0.5714463375202089
520 | Eval: Epoch: 469 dataset: RESISC45 ACC: 0.7352380952380952
521 | Eval: Epoch: 469 dataset: EuroSAT ACC: 0.8351851851851851
522 | Eval: Epoch: 469 dataset: SVHN ACC: 0.8711969883220652
523 | Eval: Epoch: 469 dataset: GTSRB ACC: 0.8205067300079176
524 | Eval: Epoch: 469 dataset: MNIST ACC: 0.9586
525 | Eval: Epoch: 469 dataset: DTD ACC: 0.5037234042553191
526 | Eval: Epoch: 469 Avg ACC:0.7382303670670292
527 |
528 | [tensor([1.0000, 0.3207, 0.1719, 0.4247, 0.5183, 0.4351, 0.5730, 0.2475, 0.2902])]
529 | Eval: Epoch: 479 dataset: SUN397 ACC: 0.6091919344295268
530 | Eval: Epoch: 479 dataset: Cars ACC: 0.5700783484641214
531 | Eval: Epoch: 479 dataset: RESISC45 ACC: 0.7333333333333333
532 | Eval: Epoch: 479 dataset: EuroSAT ACC: 0.835925925925926
533 | Eval: Epoch: 479 dataset: SVHN ACC: 0.8717347879532883
534 | Eval: Epoch: 479 dataset: GTSRB ACC: 0.8224069675376089
535 | Eval: Epoch: 479 dataset: MNIST ACC: 0.9582
536 | Eval: Epoch: 479 dataset: DTD ACC: 0.5031914893617021
537 | Eval: Epoch: 479 Avg ACC:0.7380078483756884
538 |
539 | [tensor([1.0000, 0.3207, 0.1718, 0.4252, 0.5196, 0.4364, 0.5769, 0.2461, 0.2894])]
540 | Eval: Epoch: 489 dataset: SUN397 ACC: 0.6088399456931664
541 | Eval: Epoch: 489 dataset: Cars ACC: 0.5693321726153463
542 | Eval: Epoch: 489 dataset: RESISC45 ACC: 0.733015873015873
543 | Eval: Epoch: 489 dataset: EuroSAT ACC: 0.8362962962962963
544 | Eval: Epoch: 489 dataset: SVHN ACC: 0.8721189305470191
545 | Eval: Epoch: 489 dataset: GTSRB ACC: 0.8234362628661916
546 | Eval: Epoch: 489 dataset: MNIST ACC: 0.9578
547 | Eval: Epoch: 489 dataset: DTD ACC: 0.502127659574468
548 | Eval: Epoch: 489 Avg ACC:0.737870892576045
549 |
550 | [tensor([1.0000, 0.3171, 0.1698, 0.4235, 0.5198, 0.4386, 0.5803, 0.2452, 0.2885])]
551 | Eval: Epoch: 499 dataset: SUN397 ACC: 0.6081862523256398
552 | Eval: Epoch: 499 dataset: Cars ACC: 0.5690834473324213
553 | Eval: Epoch: 499 dataset: RESISC45 ACC: 0.7317460317460317
554 | Eval: Epoch: 499 dataset: EuroSAT ACC: 0.8348148148148148
555 | Eval: Epoch: 499 dataset: SVHN ACC: 0.8730792870313461
556 | Eval: Epoch: 499 dataset: GTSRB ACC: 0.8247822644497229
557 | Eval: Epoch: 499 dataset: MNIST ACC: 0.9577
558 | Eval: Epoch: 499 dataset: DTD ACC: 0.5015957446808511
559 | Eval: Epoch: 499 Avg ACC:0.7376234802976035
560 |
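Reading these logs: each bracketed tensor appears to hold the current task-wise merging coefficients (the leading, unchanging 1.0000 plausibly being the fixed weight on the pretrained backbone, followed by one learned coefficient per task), and the Eval lines report per-dataset test accuracy plus their average at that step. Under that reading, a toy sketch of how such coefficients combine task vectors in task arithmetic (all tensors below are fabricated 1-D stand-ins for real checkpoints):

import torch

# Fabricated 1-D "checkpoints"; real ones are full model state dicts.
pretrained = torch.tensor([0.0, 1.0])
finetuned = [torch.tensor([1.0, 1.0]), torch.tensor([-2.0, 3.0])]
lambdas = [0.33, 0.27]  # per-task coefficients, like the tensor entries after the leading 1.0

# task vector = fine-tuned weights minus pretrained weights
task_vectors = [ft - pretrained for ft in finetuned]
# merged model = pretrained + sum_k lambda_k * task_vector_k
merged = pretrained + sum(l * tv for l, tv in zip(lambdas, task_vectors))
print(merged)  # tensor([-0.2100, 1.5400])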
--------------------------------------------------------------------------------
/src/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 |
6 | def parse_arguments():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument(
9 | "--data-location",
10 | type=str,
11 | default=os.path.expanduser('~/data'),
12 | help="The root directory for the datasets.",
13 | )
14 | parser.add_argument(
15 | "--eval-datasets",
16 | default=None,
17 | type=lambda x: x.split(","),
18 | help="Which datasets to use for evaluation. Split by comma, e.g. MNIST,EuroSAT. "
19 | )
20 | parser.add_argument(
21 | "--train-dataset",
22 | default=None,
23 | type=lambda x: x.split(","),
24 | help="Which dataset(s) to patch on.",
25 | )
26 | parser.add_argument(
27 | "--exp_name",
28 | type=str,
29 | default=None,
30 | help="Name of the experiment, for organization purposes only."
31 | )
32 | parser.add_argument(
33 | "--results-db",
34 | type=str,
35 | default=None,
36 | help="Where to store the results, else does not store",
37 | )
38 | parser.add_argument(
39 | "--model",
40 | type=str,
41 | default=None,
42 | help="The type of model (e.g. RN50, ViT-B-32).",
43 | )
44 | parser.add_argument(
45 | "--batch-size",
46 | type=int,
47 | default=128,
48 | )
49 | parser.add_argument(
50 | "--lr",
51 | type=float,
52 | default=0.001,
53 | help="Learning rate."
54 | )
55 | parser.add_argument(
56 | "--wd",
57 | type=float,
58 | default=0.1,
59 | help="Weight decay"
60 | )
61 | parser.add_argument(
62 | "--ls",
63 | type=float,
64 | default=0.0,
65 | help="Label smoothing."
66 | )
67 | parser.add_argument(
68 | "--warmup_length",
69 | type=int,
70 | default=500,
71 | )
72 | parser.add_argument(
73 | "--epochs",
74 | type=int,
75 | default=10,
76 | )
77 | parser.add_argument(
78 | "--load",
79 | type=lambda x: x.split(","),
80 | default=None,
81 | help="Optionally load _classifiers_, e.g. a zero shot classifier or probe or ensemble both.",
82 | )
83 | parser.add_argument(
84 | "--save",
85 | type=str,
86 | default=None,
87 | help="Optionally save a _classifier_, e.g. a zero shot classifier or probe.",
88 | )
89 | parser.add_argument(
90 | "--cache-dir",
91 | type=str,
92 | default=None,
93 | help="Directory for caching features and encoder",
94 | )
95 | parser.add_argument(
96 | "--openclip-cachedir",
97 | type=str,
98 | default='/gscratch/efml/gamaga/.cache/open_clip',
99 | help='Directory for caching models from OpenCLIP'
100 | )
101 | parsed_args = parser.parse_args()
102 | parsed_args.device = "cuda" if torch.cuda.is_available() else "cpu"
103 |
104 | if parsed_args.load is not None and len(parsed_args.load) == 1:
105 | parsed_args.load = parsed_args.load[0]
106 | return parsed_args
107 |
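For orientation, a minimal sketch of exercising parse_arguments() programmatically; the flag names come from the parser above, while the values and the sys.argv patching are made up for illustration (assumes src/ is on PYTHONPATH):

import sys
from args import parse_arguments

# Simulate a command line; the values here are illustrative only.
sys.argv = ['prog', '--model', 'ViT-B-32', '--eval-datasets', 'MNIST,EuroSAT', '--batch-size', '64']
args = parse_arguments()
assert args.eval_datasets == ['MNIST', 'EuroSAT']  # comma-split by the lambda type above
print(args.model, args.batch_size, args.device)  # device resolves to 'cuda' if available, else 'cpu'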
--------------------------------------------------------------------------------
/src/datasets/cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 |
6 | import pathlib
7 | from typing import Callable, Optional, Any, Tuple
8 |
9 | from PIL import Image
10 |
11 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg
12 | from torchvision.datasets.vision import VisionDataset
13 |
14 |
15 | class PytorchStanfordCars(VisionDataset):
16 | """`Stanford Cars `_ Dataset
17 |
18 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is
19 | split into 8,144 training images and 8,041 testing images, where each class
20 | has been split roughly 50-50.
21 |
22 | .. note::
23 |
24 | This class needs `scipy <https://scipy.org>`_ to load target files from `.mat` format.
25 |
26 | Args:
27 | root (string): Root directory of dataset
28 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
29 | transform (callable, optional): A function/transform that takes in a PIL image
30 | and returns a transformed version. E.g, ``transforms.RandomCrop``
31 | target_transform (callable, optional): A function/transform that takes in the
32 | target and transforms it.
33 | download (bool, optional): If True, downloads the dataset from the internet and
34 | puts it in root directory. If dataset is already downloaded, it is not
35 | downloaded again."""
36 |
37 | def __init__(
38 | self,
39 | root: str,
40 | split: str = "train",
41 | transform: Optional[Callable] = None,
42 | target_transform: Optional[Callable] = None,
43 | download: bool = False,
44 | ) -> None:
45 |
46 | try:
47 | import scipy.io as sio
48 | except ImportError:
49 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
50 |
51 | super().__init__(root, transform=transform, target_transform=target_transform)
52 |
53 | self._split = verify_str_arg(split, "split", ("train", "test"))
54 | self._base_folder = pathlib.Path(root) / "stanford_cars"
55 | devkit = self._base_folder / "devkit"
56 |
57 | if self._split == "train":
58 | self._annotations_mat_path = devkit / "cars_train_annos.mat"
59 | self._images_base_path = self._base_folder / "cars_train"
60 | else:
61 | self._annotations_mat_path = devkit / "cars_test_annos_withlabels.mat"
62 | self._images_base_path = self._base_folder / "cars_test"
63 |
64 | if download:
65 | self.download()
66 |
67 | if not self._check_exists():
68 | raise RuntimeError("Dataset not found. You can use download=True to download it")
69 |
70 | self._samples = [
71 | (
72 | str(self._images_base_path / annotation["fname"]),
73 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1
74 | )
75 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
76 | ]
77 |
78 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
79 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
80 |
81 | def __len__(self) -> int:
82 | return len(self._samples)
83 |
84 | def __getitem__(self, idx: int) -> Tuple[Any, Any]:
85 | """Returns pil_image and class_id for given index"""
86 | image_path, target = self._samples[idx]
87 | pil_image = Image.open(image_path).convert("RGB")
88 |
89 | if self.transform is not None:
90 | pil_image = self.transform(pil_image)
91 | if self.target_transform is not None:
92 | target = self.target_transform(target)
93 | return pil_image, target
94 |
95 |
96 | def download(self) -> None:
97 | if self._check_exists():
98 | return
99 |
100 | download_and_extract_archive(
101 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
102 | download_root=str(self._base_folder),
103 | md5="c3b158d763b6e2245038c8ad08e45376",
104 | )
105 | if self._split == "train":
106 | download_and_extract_archive(
107 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
108 | download_root=str(self._base_folder),
109 | md5="065e5b463ae28d29e77c1b4b166cfe61",
110 | )
111 | else:
112 | download_and_extract_archive(
113 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
114 | download_root=str(self._base_folder),
115 | md5="4ce7ebf6a94d07f1952d94dd34c4d501",
116 | )
117 | download_url(
118 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
119 | root=str(self._base_folder),
120 | md5="b0a2b23655a3edd16d84508592a98d10",
121 | )
122 |
123 | def _check_exists(self) -> bool:
124 | if not (self._base_folder / "devkit").is_dir():
125 | return False
126 |
127 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
128 |
129 |
130 | class Cars:
131 | def __init__(self,
132 | preprocess,
133 | location=os.path.expanduser('~/data'),
134 | batch_size=32,
135 | num_workers=0):
136 | # Data loading code
137 |
138 | self.train_dataset = PytorchStanfordCars(location, 'train', preprocess, download=False)
139 | self.train_loader = torch.utils.data.DataLoader(
140 | self.train_dataset,
141 | shuffle=True,
142 | batch_size=batch_size,
143 | num_workers=num_workers,
144 | )
145 |
146 | self.test_dataset = PytorchStanfordCars(location, 'test', preprocess, download=False)
147 | self.test_loader = torch.utils.data.DataLoader(
148 | self.test_dataset,
149 | batch_size=batch_size,
150 | num_workers=num_workers
151 | )
152 | self.test_loader_shuffle = torch.utils.data.DataLoader(
153 | self.test_dataset,
154 | shuffle=True,
155 | batch_size=batch_size,
156 | num_workers=num_workers
157 | )
158 | idx_to_class = dict((v, k) for k, v in self.train_dataset.class_to_idx.items())
159 | self.classnames = [idx_to_class[i].replace(
160 | '_', ' ') for i in range(len(idx_to_class))]
161 |
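A toy illustration of the 1-indexed-to-0-indexed target shift performed above when building self._samples; the annotation dicts here are fabricated, whereas the real ones come from cars_train_annos.mat via scipy:

# Fabricated annotations mimicking the .mat structure used above.
annotations = [{'fname': '00001.jpg', 'class': 14}, {'fname': '00002.jpg', 'class': 1}]
samples = [(a['fname'], a['class'] - 1) for a in annotations]
print(samples)  # [('00001.jpg', 13), ('00002.jpg', 0)] -- targets now start at 0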
--------------------------------------------------------------------------------
/src/datasets/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import json
4 | import glob
5 | import collections
6 | import random
7 |
8 | import numpy as np
9 |
10 | from tqdm import tqdm
11 |
12 | import torchvision.datasets as datasets
13 | from torch.utils.data import Dataset, DataLoader, Sampler
14 |
15 |
16 | class SubsetSampler(Sampler):
17 | def __init__(self, indices):
18 | self.indices = indices
19 |
20 | def __iter__(self):
21 | return (i for i in self.indices)
22 |
23 | def __len__(self):
24 | return len(self.indices)
25 |
26 | class ImageFolderWithPaths(datasets.ImageFolder):
27 | def __init__(self, path, transform, flip_label_prob=0.0):
28 | super().__init__(path, transform)
29 | self.flip_label_prob = flip_label_prob
30 | if self.flip_label_prob > 0:
31 | print(f'Flipping labels with probability {self.flip_label_prob}')
32 | num_classes = len(self.classes)
33 | for i in range(len(self.samples)):
34 | if random.random() < self.flip_label_prob:
35 | new_label = random.randint(0, num_classes-1)
36 | self.samples[i] = (
37 | self.samples[i][0],
38 | new_label
39 | )
40 |
41 | def __getitem__(self, index):
42 | image, label = super(ImageFolderWithPaths, self).__getitem__(index)
43 | return {
44 | 'images': image,
45 | 'labels': label,
46 | 'image_paths': self.samples[index][0]
47 | }
48 |
49 |
50 | def maybe_dictionarize(batch):
51 | if isinstance(batch, dict):
52 | return batch
53 |
54 | if len(batch) == 2:
55 | batch = {'images': batch[0], 'labels': batch[1]}
56 | elif len(batch) == 3:
57 | batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
58 | else:
59 | raise ValueError(f'Unexpected number of elements: {len(batch)}')
60 |
61 | return batch
62 |
63 |
64 | def get_features_helper(image_encoder, dataloader, device):
65 | all_data = collections.defaultdict(list)
66 |
67 | image_encoder = image_encoder.to(device)
68 | image_encoder = torch.nn.DataParallel(image_encoder, device_ids=[x for x in range(torch.cuda.device_count())])
69 | image_encoder.eval()
70 |
71 | with torch.no_grad():
72 | for batch in tqdm(dataloader):
73 | batch = maybe_dictionarize(batch)
74 | features = image_encoder(batch['images'].to(device))  # honor the requested device instead of assuming CUDA
75 |
76 | all_data['features'].append(features.cpu())
77 |
78 | for key, val in batch.items():
79 | if key == 'images':
80 | continue
81 | if hasattr(val, 'cpu'):
82 | val = val.cpu()
83 | all_data[key].append(val)
84 | else:
85 | all_data[key].extend(val)
86 |
87 | for key, val in all_data.items():
88 | if torch.is_tensor(val[0]):
89 | all_data[key] = torch.cat(val).numpy()
90 |
91 | return all_data
92 |
93 |
94 | def get_features(is_train, image_encoder, dataset, device):
95 | split = 'train' if is_train else 'val'
96 | dname = type(dataset).__name__
97 | # define cache_dir/cached_files unconditionally so the branches below cannot hit a NameError
98 | cache_dir = f'{image_encoder.cache_dir}/{dname}/{split}' if image_encoder.cache_dir is not None else None
99 | cached_files = glob.glob(f'{cache_dir}/*') if cache_dir is not None else []
100 | if image_encoder.cache_dir is not None and len(cached_files) > 0:
101 | print(f'Getting features from {cache_dir}')
102 | data = {}
103 | for cached_file in cached_files:
104 | name = os.path.splitext(os.path.basename(cached_file))[0]
105 | data[name] = torch.load(cached_file)
106 | else:
107 | print(f'Did not find cached features at {cache_dir}. Building from scratch.')
108 | loader = dataset.train_loader if is_train else dataset.test_loader
109 | data = get_features_helper(image_encoder, loader, device)
110 | if image_encoder.cache_dir is None:
111 | print('Not caching because no cache directory was passed.')
112 | else:
113 | os.makedirs(cache_dir, exist_ok=True)
114 | print(f'Caching data at {cache_dir}')
115 | for name, val in data.items():
116 | torch.save(val, f'{cache_dir}/{name}.pt')
117 | return data
118 |
119 |
120 | class FeatureDataset(Dataset):
121 | def __init__(self, is_train, image_encoder, dataset, device):
122 | self.data = get_features(is_train, image_encoder, dataset, device)
123 |
124 | def __len__(self):
125 | return len(self.data['features'])
126 |
127 | def __getitem__(self, idx):
128 | data = {k: v[idx] for k, v in self.data.items()}
129 | data['features'] = torch.from_numpy(data['features']).float()
130 | return data
131 |
132 |
133 | def get_dataloader(dataset, is_train, args, image_encoder=None):
134 | if image_encoder is not None:
135 | feature_dataset = FeatureDataset(is_train, image_encoder, dataset, args.device)
136 | dataloader = DataLoader(feature_dataset, batch_size=args.batch_size, shuffle=is_train)
137 | else:
138 | dataloader = dataset.train_loader if is_train else dataset.test_loader
139 | return dataloader
140 |
141 | def get_dataloader_shuffle(dataset):
142 | dataloader = dataset.test_loader_shuffle
143 | return dataloader
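A minimal sketch of maybe_dictionarize on a plain (images, labels) tuple batch, with dummy tensors standing in for real data (assumes src/ is on PYTHONPATH):

import torch
from datasets.common import maybe_dictionarize

batch = (torch.randn(4, 3, 224, 224), torch.tensor([0, 1, 2, 3]))
batch = maybe_dictionarize(batch)
print(sorted(batch.keys()))  # ['images', 'labels']; a 3-tuple would also gain a 'metadata' key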
--------------------------------------------------------------------------------
/src/datasets/dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 |
6 | class DTD:
7 | def __init__(self,
8 | preprocess,
9 | location=os.path.expanduser('~/data'),
10 | batch_size=32,
11 | num_workers=0):
12 | # Data loading code
13 | traindir = os.path.join(location, 'dtd', 'train')
14 | valdir = os.path.join(location, 'dtd', 'test')
15 |
16 | self.train_dataset = datasets.ImageFolder(
17 | traindir, transform=preprocess)
18 | self.train_loader = torch.utils.data.DataLoader(
19 | self.train_dataset,
20 | shuffle=True,
21 | batch_size=batch_size,
22 | num_workers=num_workers,
23 | )
24 |
25 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess)
26 | self.test_loader = torch.utils.data.DataLoader(
27 | self.test_dataset,
28 | batch_size=batch_size,
29 | num_workers=num_workers
30 | )
31 |
32 | self.test_loader_shuffle = torch.utils.data.DataLoader(
33 | self.test_dataset,
34 | shuffle=True,
35 | batch_size=batch_size,
36 | num_workers=num_workers
37 | )
38 |
39 | idx_to_class = dict((v, k)
40 | for k, v in self.train_dataset.class_to_idx.items())
41 | self.classnames = [idx_to_class[i].replace(
42 | '_', ' ') for i in range(len(idx_to_class))]
--------------------------------------------------------------------------------
/src/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 | import re
5 |
6 | def pretify_classname(classname):
7 | l = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))', classname)
8 | l = [i.lower() for i in l]
9 | out = ' '.join(l)
10 | if out.endswith('al'):
11 | return out + ' area'
12 | return out
13 |
14 | class EuroSATBase:
15 | def __init__(self,
16 | preprocess,
17 | test_split,
18 | location=os.path.expanduser('~/data'),
19 | batch_size=32,
20 | num_workers=0):
21 | # Data loading code
22 | traindir = os.path.join(location, 'EuroSAT_splits', 'train')
23 | testdir = os.path.join(location, 'EuroSAT_splits', test_split)
24 |
25 |
26 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess)
27 | self.train_loader = torch.utils.data.DataLoader(
28 | self.train_dataset,
29 | shuffle=True,
30 | batch_size=batch_size,
31 | num_workers=num_workers,
32 | )
33 |
34 | self.test_dataset = datasets.ImageFolder(testdir, transform=preprocess)
35 | self.test_loader = torch.utils.data.DataLoader(
36 | self.test_dataset,
37 | batch_size=batch_size,
38 | num_workers=num_workers
39 | )
40 | self.test_loader_shuffle = torch.utils.data.DataLoader(
41 | self.test_dataset,
42 | shuffle=True,
43 | batch_size=batch_size,
44 | num_workers=num_workers
45 | )
46 | idx_to_class = dict((v, k)
47 | for k, v in self.train_dataset.class_to_idx.items())
48 |
49 | self.classnames = [idx_to_class[i].replace('_', ' ') for i in range(len(idx_to_class))]
50 | self.classnames = [pretify_classname(c) for c in self.classnames]
51 | ours_to_open_ai = {
52 | 'annual crop': 'annual crop land',
53 | 'forest': 'forest',
54 | 'herbaceous vegetation': 'brushland or shrubland',
55 | 'highway': 'highway or road',
56 | 'industrial area': 'industrial buildings or commercial buildings',
57 | 'pasture': 'pasture land',
58 | 'permanent crop': 'permanent crop land',
59 | 'residential area': 'residential buildings or homes or apartments',
60 | 'river': 'river',
61 | 'sea lake': 'lake or sea',
62 | }
63 | for i in range(len(self.classnames)):
64 | self.classnames[i] = ours_to_open_ai[self.classnames[i]]
65 |
66 |
67 | class EuroSAT(EuroSATBase):
68 | def __init__(self,
69 | preprocess,
70 | location=os.path.expanduser('~/data'),
71 | batch_size=32,
72 | num_workers=0):
73 | super().__init__(preprocess, 'test', location, batch_size, num_workers)
74 |
75 |
76 | class EuroSATVal(EuroSATBase):
77 | def __init__(self,
78 | preprocess,
79 | location=os.path.expanduser('~/data'),
80 | batch_size=32,
81 | num_workers=0):
82 | super().__init__(preprocess, 'val', location, batch_size, num_workers)
83 |
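The CamelCase-splitting helper above can be checked in isolation; note the 'al' suffix rule, which turns class names like 'Industrial' into 'industrial area' so they line up with the ours_to_open_ai keys (assumes src/ is on PYTHONPATH):

from datasets.eurosat import pretify_classname

print(pretify_classname('AnnualCrop'))  # 'annual crop'
print(pretify_classname('Industrial'))  # 'industrial area'
print(pretify_classname('SeaLake'))     # 'sea lake'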
--------------------------------------------------------------------------------
/src/datasets/gtsrb.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import pathlib
4 | from typing import Any, Callable, Dict, List, Optional, Tuple
5 |
6 | import numpy as np
7 | import PIL
8 | import torch
9 | from torchvision.datasets.folder import make_dataset
10 | from torchvision.datasets.utils import (download_and_extract_archive, verify_str_arg)
11 | from torchvision.datasets.vision import VisionDataset
12 |
13 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
14 | """Finds the class folders in a dataset.
15 |
16 | See :class:`DatasetFolder` for details.
17 | """
18 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
19 | if not classes:
20 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
21 |
22 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
23 | return classes, class_to_idx
24 |
25 | class PyTorchGTSRB(VisionDataset):
26 | """`German Traffic Sign Recognition Benchmark (GTSRB) `_ Dataset.
27 |
28 | Modified from https://pytorch.org/vision/main/_modules/torchvision/datasets/gtsrb.html#GTSRB.
29 |
30 | Args:
31 | root (string): Root directory of the dataset.
32 | split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
33 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
34 | version. E.g, ``transforms.RandomCrop``.
35 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
36 | download (bool, optional): If True, downloads the dataset from the internet and
37 | puts it in root directory. If dataset is already downloaded, it is not
38 | downloaded again.
39 | """
40 |
41 | def __init__(
42 | self,
43 | root: str,
44 | split: str = "train",
45 | transform: Optional[Callable] = None,
46 | target_transform: Optional[Callable] = None,
47 | download: bool = False,
48 | ) -> None:
49 |
50 | super().__init__(root, transform=transform, target_transform=target_transform)
51 |
52 | self._split = verify_str_arg(split, "split", ("train", "test"))
53 | self._base_folder = pathlib.Path(root) / "gtsrb"
54 | self._target_folder = (
55 | self._base_folder / "GTSRB" / ("Training" if self._split == "train" else "Final_Test/Images")
56 | )
57 |
58 | if download:
59 | self.download()
60 |
61 | if not self._check_exists():
62 | raise RuntimeError("Dataset not found. You can use download=True to download it")
63 |
64 | if self._split == "train":
65 | _, class_to_idx = find_classes(str(self._target_folder))
66 | samples = make_dataset(str(self._target_folder), extensions=(".ppm",), class_to_idx=class_to_idx)
67 | else:
68 | with open(self._base_folder / "GT-final_test.csv") as csv_file:
69 | samples = [
70 | (str(self._target_folder / row["Filename"]), int(row["ClassId"]))
71 | for row in csv.DictReader(csv_file, delimiter=";", skipinitialspace=True)
72 | ]
73 |
74 | self._samples = samples
75 | self.transform = transform
76 | self.target_transform = target_transform
77 |
78 | def __len__(self) -> int:
79 | return len(self._samples)
80 |
81 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
82 |
83 | path, target = self._samples[index]
84 | sample = PIL.Image.open(path).convert("RGB")
85 |
86 | if self.transform is not None:
87 | sample = self.transform(sample)
88 |
89 | if self.target_transform is not None:
90 | target = self.target_transform(target)
91 |
92 | return sample, target
93 |
94 |
95 | def _check_exists(self) -> bool:
96 | return self._target_folder.is_dir()
97 |
98 | def download(self) -> None:
99 | if self._check_exists():
100 | return
101 |
102 | base_url = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/"
103 |
104 | if self._split == "train":
105 | download_and_extract_archive(
106 | f"{base_url}GTSRB-Training_fixed.zip",
107 | download_root=str(self._base_folder),
108 | md5="513f3c79a4c5141765e10e952eaa2478",
109 | )
110 | else:
111 | download_and_extract_archive(
112 | f"{base_url}GTSRB_Final_Test_Images.zip",
113 | download_root=str(self._base_folder),
114 | md5="c7e4e6327067d32654124b0fe9e82185",
115 | )
116 | download_and_extract_archive(
117 | f"{base_url}GTSRB_Final_Test_GT.zip",
118 | download_root=str(self._base_folder),
119 | md5="fe31e9c9270bbcd7b84b7f21a9d9d9e5",
120 | )
121 |
122 |
123 | class GTSRB:
124 | def __init__(self,
125 | preprocess,
126 | location=os.path.expanduser('~/data'),
127 | batch_size=128,
128 | num_workers=0):
129 |
130 | # to fit with repo conventions for location
131 | self.train_dataset = PyTorchGTSRB(
132 | root=location,
133 | download=False,
134 | split='train',
135 | transform=preprocess
136 | )
137 |
138 | self.train_loader = torch.utils.data.DataLoader(
139 | self.train_dataset,
140 | batch_size=batch_size,
141 | shuffle=True,
142 | num_workers=num_workers
143 | )
144 |
145 | self.test_dataset = PyTorchGTSRB(
146 | root=location,
147 | download=False,
148 | split='test',
149 | transform=preprocess
150 | )
151 |
152 | self.test_loader = torch.utils.data.DataLoader(
153 | self.test_dataset,
154 | batch_size=batch_size,
155 | shuffle=False,
156 | num_workers=num_workers
157 | )
158 |
159 | self.test_loader_shuffle = torch.utils.data.DataLoader(
160 | self.test_dataset,
161 | batch_size=batch_size,
162 | shuffle=True,
163 | num_workers=num_workers
164 | )
165 |
166 | # from https://github.com/openai/CLIP/blob/e184f608c5d5e58165682f7c332c3a8b4c1545f2/data/prompts.md
167 | self.classnames = [
168 | 'red and white circle 20 kph speed limit',
169 | 'red and white circle 30 kph speed limit',
170 | 'red and white circle 50 kph speed limit',
171 | 'red and white circle 60 kph speed limit',
172 | 'red and white circle 70 kph speed limit',
173 | 'red and white circle 80 kph speed limit',
174 | 'end / de-restriction of 80 kph speed limit',
175 | 'red and white circle 100 kph speed limit',
176 | 'red and white circle 120 kph speed limit',
177 | 'red and white circle red car and black car no passing',
178 | 'red and white circle red truck and black car no passing',
179 | 'red and white triangle road intersection warning',
180 | 'white and yellow diamond priority road',
181 | 'red and white upside down triangle yield right-of-way',
182 | 'stop',
183 | 'empty red and white circle',
184 | 'red and white circle no truck entry',
185 | 'red circle with white horizonal stripe no entry',
186 | 'red and white triangle with exclamation mark warning',
187 | 'red and white triangle with black left curve approaching warning',
188 | 'red and white triangle with black right curve approaching warning',
189 | 'red and white triangle with black double curve approaching warning',
190 | 'red and white triangle rough / bumpy road warning',
191 | 'red and white triangle car skidding / slipping warning',
192 | 'red and white triangle with merging / narrow lanes warning',
193 | 'red and white triangle with person digging / construction / road work warning',
194 | 'red and white triangle with traffic light approaching warning',
195 | 'red and white triangle with person walking warning',
196 | 'red and white triangle with child and person walking warning',
197 | 'red and white triangle with bicyle warning',
198 | 'red and white triangle with snowflake / ice warning',
199 | 'red and white triangle with deer warning',
200 | 'white circle with gray strike bar no speed limit',
201 | 'blue circle with white right turn arrow mandatory',
202 | 'blue circle with white left turn arrow mandatory',
203 | 'blue circle with white forward arrow mandatory',
204 | 'blue circle with white forward or right turn arrow mandatory',
205 | 'blue circle with white forward or left turn arrow mandatory',
206 | 'blue circle with white keep right arrow mandatory',
207 | 'blue circle with white keep left arrow mandatory',
208 | 'blue circle with white arrows indicating a traffic circle',
209 | 'white circle with gray strike bar indicating no passing for cars has ended',
210 | 'white circle with gray strike bar indicating no passing for trucks has ended',
211 | ]
212 |
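A self-contained sketch of the test-split CSV parsing above, run on an in-memory file with two fabricated rows; the real annotations live in GT-final_test.csv and use the same ';' delimiter:

import csv
import io

csv_file = io.StringIO('Filename;ClassId\n00000.ppm;16\n00001.ppm;1\n')
samples = [(row['Filename'], int(row['ClassId']))
           for row in csv.DictReader(csv_file, delimiter=';', skipinitialspace=True)]
print(samples)  # [('00000.ppm', 16), ('00001.ppm', 1)]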
--------------------------------------------------------------------------------
/src/datasets/mnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | class MNIST:
6 | def __init__(self,
7 | preprocess,
8 | location=os.path.expanduser('~/data'),
9 | batch_size=128,
10 | num_workers=0):
11 |
12 |
13 | self.train_dataset = datasets.MNIST(
14 | root=location,
15 | download=True,
16 | train=True,
17 | transform=preprocess
18 | )
19 |
20 | self.train_loader = torch.utils.data.DataLoader(
21 | self.train_dataset,
22 | batch_size=batch_size,
23 | shuffle=True,
24 | num_workers=num_workers
25 | )
26 |
27 | self.test_dataset = datasets.MNIST(
28 | root=location,
29 | download=True,
30 | train=False,
31 | transform=preprocess
32 | )
33 |
34 | self.test_loader = torch.utils.data.DataLoader(
35 | self.test_dataset,
36 | batch_size=batch_size,
37 | shuffle=False,
38 | num_workers=num_workers
39 | )
40 |
41 | self.test_loader_shuffle = torch.utils.data.DataLoader(
42 | self.test_dataset,
43 | batch_size=batch_size,
44 | shuffle=True,
45 | num_workers=num_workers
46 | )
47 |
48 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
--------------------------------------------------------------------------------
/src/datasets/registry.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import inspect
3 | import random
4 | import torch
5 | import copy
6 |
7 | from torch.utils.data.dataset import random_split
8 |
9 | from datasets.cars import Cars
10 | from datasets.dtd import DTD
11 | from datasets.eurosat import EuroSAT, EuroSATVal
12 | from datasets.gtsrb import GTSRB
13 | from datasets.mnist import MNIST
14 | from datasets.resisc45 import RESISC45
15 | from datasets.svhn import SVHN
16 | from datasets.sun397 import SUN397
17 |
18 | registry = {
19 | name: obj for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
20 | }
21 |
22 |
23 | class GenericDataset(object):
24 | def __init__(self):
25 | self.train_dataset = None
26 | self.train_loader = None
27 | self.test_dataset = None
28 | self.test_loader = None
29 | self.classnames = None
30 |
31 |
32 | def split_train_into_train_val(dataset, new_dataset_class_name, batch_size, num_workers, val_fraction, max_val_samples=None, seed=0):
33 | assert val_fraction > 0. and val_fraction < 1.
34 | total_size = len(dataset.train_dataset)
35 | val_size = int(total_size * val_fraction)
36 | if max_val_samples is not None:
37 | val_size = min(val_size, max_val_samples)
38 | train_size = total_size - val_size
39 |
40 | assert val_size > 0
41 | assert train_size > 0
42 |
43 | lengths = [train_size, val_size]
44 |
45 | trainset, valset = random_split(
46 | dataset.train_dataset,
47 | lengths,
48 | generator=torch.Generator().manual_seed(seed)
49 | )
50 | if new_dataset_class_name == 'MNISTVal':
51 | assert trainset.indices[0] == 36044  # sanity check: the seed-0 split must reproduce deterministically
52 |
53 |
54 | new_dataset = None
55 |
56 | new_dataset_class = type(new_dataset_class_name, (GenericDataset, ), {})
57 | new_dataset = new_dataset_class()
58 |
59 | new_dataset.train_dataset = trainset
60 | new_dataset.train_loader = torch.utils.data.DataLoader(
61 | new_dataset.train_dataset,
62 | shuffle=True,
63 | batch_size=batch_size,
64 | num_workers=num_workers,
65 | )
66 |
67 | new_dataset.test_dataset = valset
68 | new_dataset.test_loader = torch.utils.data.DataLoader(
69 | new_dataset.test_dataset,
70 | batch_size=batch_size,
71 | num_workers=num_workers
72 | )
73 |
74 | new_dataset.test_loader_shuffle = torch.utils.data.DataLoader(
75 | new_dataset.test_dataset,
76 | shuffle=True,
77 | batch_size=batch_size,
78 | num_workers=num_workers
79 | )
80 |
81 | new_dataset.classnames = copy.copy(dataset.classnames)
82 |
83 | return new_dataset
84 |
85 |
86 | def get_dataset(dataset_name, preprocess, location, batch_size=128, num_workers=0, val_fraction=0.1, max_val_samples=5000):
87 | if dataset_name.endswith('Val'):
88 | # Handle val splits
89 | if dataset_name in registry:
90 | dataset_class = registry[dataset_name]
91 | else:
92 | base_dataset_name = dataset_name.split('Val')[0]
93 | base_dataset = get_dataset(base_dataset_name, preprocess, location, batch_size, num_workers)
94 | dataset = split_train_into_train_val(
95 | base_dataset, dataset_name, batch_size, num_workers, val_fraction, max_val_samples)
96 | return dataset
97 | else:
98 | assert dataset_name in registry, f'Unsupported dataset: {dataset_name}. Supported datasets: {list(registry.keys())}'
99 | dataset_class = registry[dataset_name]
100 | dataset = dataset_class(
101 | preprocess, location=location, batch_size=batch_size, num_workers=num_workers
102 | )
103 | return dataset
104 |
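A usage sketch of the 'Val' routing above: requesting 'MNISTVal' carves a validation split out of MNIST's training set via split_train_into_train_val, while 'MNIST' returns the registered class directly. The preprocess pipeline below is a made-up stand-in, src/ is assumed to be on PYTHONPATH, and the data downloads to ~/data on first use:

import os
from torchvision import transforms
from datasets.registry import get_dataset

preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # MNIST is single-channel
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = get_dataset('MNISTVal', preprocess, location=os.path.expanduser('~/data'))
print(len(dataset.test_dataset))  # the val split, capped at max_val_samples (5000 by default)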
--------------------------------------------------------------------------------
/src/datasets/resisc45.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | import abc
5 | import os
6 | from typing import Any, Callable, Dict, Optional, Tuple
7 |
8 | import numpy as np
9 | import torch
10 | from torch import Tensor
11 | from torch.utils.data import Dataset
12 | from torchvision.datasets import ImageFolder
13 | from torchvision.datasets.folder import default_loader as pil_loader
14 |
15 |
16 | # modified from: https://github.com/microsoft/torchgeo
17 | class VisionDataset(Dataset[Dict[str, Any]], abc.ABC):
18 | """Abstract base class for datasets lacking geospatial information.
19 | This base class is designed for datasets with pre-defined image chips.
20 | """
21 |
22 | @abc.abstractmethod
23 | def __getitem__(self, index: int) -> Dict[str, Any]:
24 | """Return an index within the dataset.
25 | Args:
26 | index: index to return
27 | Returns:
28 | data and labels at that index
29 | Raises:
30 | IndexError: if index is out of range of the dataset
31 | """
32 |
33 | @abc.abstractmethod
34 | def __len__(self) -> int:
35 | """Return the length of the dataset.
36 | Returns:
37 | length of the dataset
38 | """
39 |
40 | def __str__(self) -> str:
41 | """Return the informal string representation of the object.
42 | Returns:
43 | informal string representation
44 | """
45 | return f"""\
46 | {self.__class__.__name__} Dataset
47 | type: VisionDataset
48 | size: {len(self)}"""
49 |
50 |
51 | class VisionClassificationDataset(VisionDataset, ImageFolder):
52 | """Abstract base class for classification datasets lacking geospatial information.
53 | This base class is designed for datasets with pre-defined image chips which
54 | are separated into separate folders per class.
55 | """
56 |
57 | def __init__(
58 | self,
59 | root: str,
60 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
61 | loader: Optional[Callable[[str], Any]] = pil_loader,
62 | is_valid_file: Optional[Callable[[str], bool]] = None,
63 | ) -> None:
64 | """Initialize a new VisionClassificationDataset instance.
65 | Args:
66 | root: root directory where dataset can be found
67 | transforms: a function/transform that takes input sample and its target as
68 | entry and returns a transformed version
69 | loader: a callable function which takes as input a path to an image and
70 | returns a PIL Image or numpy array
71 | is_valid_file: A function that takes the path of an Image file and checks if
72 | the file is a valid file
73 | """
74 | # When transform & target_transform are None, ImageFolder.__getitem__(index)
75 | # returns a PIL.Image and int for image and label, respectively
76 | super().__init__(
77 | root=root,
78 | transform=None,
79 | target_transform=None,
80 | loader=loader,
81 | is_valid_file=is_valid_file,
82 | )
83 |
84 | # Must be set after calling super().__init__()
85 | self.transforms = transforms
86 |
87 | def __getitem__(self, index: int) -> Dict[str, Tensor]:
88 | """Return an index within the dataset.
89 | Args:
90 | index: index to return
91 | Returns:
92 | data and label at that index
93 | """
94 | image, label = self._load_image(index)
95 |
96 | if self.transforms is not None:
97 | return self.transforms(image), label
98 |
99 | return image, label
100 |
101 | def __len__(self) -> int:
102 | """Return the number of data points in the dataset.
103 | Returns:
104 | length of the dataset
105 | """
106 | return len(self.imgs)
107 |
108 | def _load_image(self, index: int) -> Tuple[Tensor, Tensor]:
109 | """Load a single image and it's class label.
110 | Args:
111 | index: index to return
112 | Returns:
113 | the image
114 | the image class label
115 | """
116 | img, label = ImageFolder.__getitem__(self, index)
117 | label = torch.tensor(label)
118 | return img, label
119 |
120 |
121 | class RESISC45Dataset(VisionClassificationDataset):
122 | """RESISC45 dataset.
122 | The RESISC45
123 | dataset is a dataset for remote sensing image scene classification.
125 | Dataset features:
126 | * 31,500 images with 0.2-30 m per pixel resolution (256x256 px)
127 | * three spectral bands - RGB
128 | * 45 scene classes, 700 images per class
129 | * images extracted from Google Earth from over 100 countries
130 | * images conditions with high variability (resolution, weather, illumination)
131 | Dataset format:
132 | * images are three-channel jpgs
133 | Dataset classes:
134 | 0. airplane
135 | 1. airport
136 | 2. baseball_diamond
137 | 3. basketball_court
138 | 4. beach
139 | 5. bridge
140 | 6. chaparral
141 | 7. church
142 | 8. circular_farmland
143 | 9. cloud
144 | 10. commercial_area
145 | 11. dense_residential
146 | 12. desert
147 | 13. forest
148 | 14. freeway
149 | 15. golf_course
150 | 16. ground_track_field
151 | 17. harbor
152 | 18. industrial_area
153 | 19. intersection
154 | 20. island
155 | 21. lake
156 | 22. meadow
157 | 23. medium_residential
158 | 24. mobile_home_park
159 | 25. mountain
160 | 26. overpass
161 | 27. palace
162 | 28. parking_lot
163 | 29. railway
164 | 30. railway_station
165 | 31. rectangular_farmland
166 | 32. river
167 | 33. roundabout
168 | 34. runway
169 | 35. sea_ice
170 | 36. ship
171 | 37. snowberg
172 | 38. sparse_residential
173 | 39. stadium
174 | 40. storage_tank
175 | 41. tennis_court
176 | 42. terrace
177 | 43. thermal_power_station
178 | 44. wetland
179 | This dataset uses the train/val/test splits defined in the "In-domain representation
180 | learning for remote sensing" paper:
181 | * https://arxiv.org/abs/1911.06721
182 | If you use this dataset in your research, please cite the following paper:
183 | * https://doi.org/10.1109/jproc.2017.2675998
184 | """
185 |
186 | # url = "https://drive.google.com/file/d/1DnPSU5nVSN7xv95bpZ3XQ0JhKXZOKgIv"
187 | # md5 = "d824acb73957502b00efd559fc6cfbbb"
188 | # filename = "NWPU-RESISC45.rar"
189 | directory = "resisc45/NWPU-RESISC45"
190 |
191 | splits = ["train", "val", "test"]
192 | split_urls = {
193 | "train": "https://storage.googleapis.com/remote_sensing_representations/resisc45-train.txt", # noqa: E501
194 | "val": "https://storage.googleapis.com/remote_sensing_representations/resisc45-val.txt", # noqa: E501
195 | "test": "https://storage.googleapis.com/remote_sensing_representations/resisc45-test.txt", # noqa: E501
196 | }
197 | split_md5s = {
198 | "train": "b5a4c05a37de15e4ca886696a85c403e",
199 | "val": "a0770cee4c5ca20b8c32bbd61e114805",
200 | "test": "3dda9e4988b47eb1de9f07993653eb08",
201 | }
202 | classes = [
203 | "airplane",
204 | "airport",
205 | "baseball_diamond",
206 | "basketball_court",
207 | "beach",
208 | "bridge",
209 | "chaparral",
210 | "church",
211 | "circular_farmland",
212 | "cloud",
213 | "commercial_area",
214 | "dense_residential",
215 | "desert",
216 | "forest",
217 | "freeway",
218 | "golf_course",
219 | "ground_track_field",
220 | "harbor",
221 | "industrial_area",
222 | "intersection",
223 | "island",
224 | "lake",
225 | "meadow",
226 | "medium_residential",
227 | "mobile_home_park",
228 | "mountain",
229 | "overpass",
230 | "palace",
231 | "parking_lot",
232 | "railway",
233 | "railway_station",
234 | "rectangular_farmland",
235 | "river",
236 | "roundabout",
237 | "runway",
238 | "sea_ice",
239 | "ship",
240 | "snowberg",
241 | "sparse_residential",
242 | "stadium",
243 | "storage_tank",
244 | "tennis_court",
245 | "terrace",
246 | "thermal_power_station",
247 | "wetland",
248 | ]
249 |
250 | def __init__(
251 | self,
252 | root: str = "data",
253 | split: str = "train",
254 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None,
255 | ) -> None:
256 | """Initialize a new RESISC45 dataset instance.
257 | Args:
258 | root: root directory where dataset can be found
259 | split: one of "train", "val", or "test"
260 | transforms: a function/transform that takes input sample and its target as
261 | entry and returns a transformed version
262 | """
263 | assert split in self.splits
264 | self.root = root
265 |
266 | valid_fns = set()
267 | with open(os.path.join(self.root, "resisc45", f"resisc45-{split}.txt")) as f:
268 | for fn in f:
269 | valid_fns.add(fn.strip())
270 | is_in_split: Callable[[str], bool] = lambda x: os.path.basename(
271 | x) in valid_fns
272 |
273 | super().__init__(
274 | root=os.path.join(root, self.directory),
275 | transforms=transforms,
276 | is_valid_file=is_in_split,
277 | )
278 |
279 |
280 |
281 | class RESISC45:
282 | def __init__(self,
283 | preprocess,
284 | location=os.path.expanduser('~/data'),
285 | batch_size=32,
286 | num_workers=0):
287 |
288 | self.train_dataset = RESISC45Dataset(root=location, split='train', transforms=preprocess)
289 | self.train_loader = torch.utils.data.DataLoader(
290 | self.train_dataset,
291 | shuffle=True,
292 | batch_size=batch_size,
293 | num_workers=num_workers,
294 | )
295 |
296 | self.test_dataset = RESISC45Dataset(root=location, split='test', transforms=preprocess)
297 | self.test_loader = torch.utils.data.DataLoader(
298 | self.test_dataset,
299 | batch_size=batch_size,
300 | num_workers=num_workers
301 | )
302 | self.test_loader_shuffle = torch.utils.data.DataLoader(
303 | self.test_dataset,
304 | shuffle=True,
305 | batch_size=batch_size,
306 | num_workers=num_workers
307 | )
308 |
309 | # class names have _ so split on this for better zero-shot head
310 | self.classnames = [' '.join(c.split('_')) for c in RESISC45Dataset.classes]
311 |
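A small sketch of the filename-based split filter defined in RESISC45Dataset.__init__, using a fabricated valid_fns set in place of the resisc45-{split}.txt contents:

import os

valid_fns = {'airplane_001.jpg', 'beach_002.jpg'}  # fabricated; really read from the split file
is_in_split = lambda x: os.path.basename(x) in valid_fns
print(is_in_split('/data/NWPU-RESISC45/airplane/airplane_001.jpg'))  # True
print(is_in_split('/data/NWPU-RESISC45/airplane/airplane_999.jpg'))  # False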
--------------------------------------------------------------------------------
/src/datasets/sun397.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.datasets as datasets
4 |
5 | class SUN397:
6 | def __init__(self,
7 | preprocess,
8 | location=os.path.expanduser('~/data'),
9 | batch_size=32,
10 | num_workers=0):
11 | # Data loading code
12 | traindir = os.path.join(location, 'sun397', 'train')
13 | valdir = os.path.join(location, 'sun397', 'test')
14 |
15 |
16 | self.train_dataset = datasets.ImageFolder(traindir, transform=preprocess)
17 | self.train_loader = torch.utils.data.DataLoader(
18 | self.train_dataset,
19 | shuffle=True,
20 | batch_size=batch_size,
21 | num_workers=num_workers,
22 | )
23 |
24 | self.test_dataset = datasets.ImageFolder(valdir, transform=preprocess)
25 | self.test_loader = torch.utils.data.DataLoader(
26 | self.test_dataset,
27 | batch_size=batch_size,
28 | num_workers=num_workers
29 | )
30 | self.test_loader_shuffle = torch.utils.data.DataLoader(
31 | self.test_dataset,
32 | shuffle=True,
33 | batch_size=batch_size,
34 | num_workers=num_workers
35 | )
36 | idx_to_class = dict((v, k)
37 | for k, v in self.train_dataset.class_to_idx.items())
38 | self.classnames = [idx_to_class[i][2:].replace('_', ' ') for i in range(len(idx_to_class))]
39 |
--------------------------------------------------------------------------------
/src/datasets/svhn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.datasets import SVHN as PyTorchSVHN
4 | import numpy as np
5 |
6 |
7 | class SVHN:
8 | def __init__(self,
9 | preprocess,
10 | location=os.path.expanduser('~/data'),
11 | batch_size=128,
12 | num_workers=0):
13 |
14 | # to fit with repo conventions for location
15 | modified_location = os.path.join(location, 'svhn')
16 |
17 | self.train_dataset = PyTorchSVHN(
18 | root=modified_location,
19 | download=True,
20 | split='train',
21 | transform=preprocess
22 | )
23 |
24 | self.train_loader = torch.utils.data.DataLoader(
25 | self.train_dataset,
26 | batch_size=batch_size,
27 | shuffle=True,
28 | num_workers=num_workers
29 | )
30 |
31 | self.test_dataset = PyTorchSVHN(
32 | root=modified_location,
33 | download=True,
34 | split='test',
35 | transform=preprocess
36 | )
37 |
38 | self.test_loader = torch.utils.data.DataLoader(
39 | self.test_dataset,
40 | batch_size=batch_size,
41 | shuffle=False,
42 | num_workers=num_workers
43 | )
44 | self.test_loader_shuffle = torch.utils.data.DataLoader(
45 | self.test_dataset,
46 | batch_size=batch_size,
47 | shuffle=True,
48 | num_workers=num_workers
49 | )
50 |
51 | self.classnames = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
52 |
--------------------------------------------------------------------------------
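
The dataset wrappers above (RESISC45, SUN397, SVHN) share one contract: given a CLIP preprocessing transform, each exposes train/test datasets, plain and shuffled loaders, and human-readable classnames. A minimal usage sketch, assuming open_clip is installed and the data can be downloaded under the default location:

    import open_clip

    # create_model_and_transforms returns (model, train_preprocess, val_preprocess)
    _, _, val_preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='openai')

    dataset = SVHN(val_preprocess, batch_size=128)    # downloads SVHN on first use
    print(dataset.classnames)                         # ['0', '1', ..., '9']
    images, labels = next(iter(dataset.test_loader))  # one preprocessed batch
    print(images.shape)                               # e.g. torch.Size([128, 3, 224, 224])
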
/src/datasets/templates.py:
--------------------------------------------------------------------------------
1 | cars_template = [
2 | lambda c: f'a photo of a {c}.',
3 | lambda c: f'a photo of the {c}.',
4 | lambda c: f'a photo of my {c}.',
5 | lambda c: f'i love my {c}!',
6 | lambda c: f'a photo of my dirty {c}.',
7 | lambda c: f'a photo of my clean {c}.',
8 | lambda c: f'a photo of my new {c}.',
9 | lambda c: f'a photo of my old {c}.',
10 | ]
11 |
12 | cifar10_template = [
13 | lambda c: f'a photo of a {c}.',
14 | lambda c: f'a blurry photo of a {c}.',
15 | lambda c: f'a black and white photo of a {c}.',
16 | lambda c: f'a low contrast photo of a {c}.',
17 | lambda c: f'a high contrast photo of a {c}.',
18 | lambda c: f'a bad photo of a {c}.',
19 | lambda c: f'a good photo of a {c}.',
20 | lambda c: f'a photo of a small {c}.',
21 | lambda c: f'a photo of a big {c}.',
22 | lambda c: f'a photo of the {c}.',
23 | lambda c: f'a blurry photo of the {c}.',
24 | lambda c: f'a black and white photo of the {c}.',
25 | lambda c: f'a low contrast photo of the {c}.',
26 | lambda c: f'a high contrast photo of the {c}.',
27 | lambda c: f'a bad photo of the {c}.',
28 | lambda c: f'a good photo of the {c}.',
29 | lambda c: f'a photo of the small {c}.',
30 | lambda c: f'a photo of the big {c}.',
31 | ]
32 |
33 | cifar100_template = [
34 | lambda c: f'a photo of a {c}.',
35 | lambda c: f'a blurry photo of a {c}.',
36 | lambda c: f'a black and white photo of a {c}.',
37 | lambda c: f'a low contrast photo of a {c}.',
38 | lambda c: f'a high contrast photo of a {c}.',
39 | lambda c: f'a bad photo of a {c}.',
40 | lambda c: f'a good photo of a {c}.',
41 | lambda c: f'a photo of a small {c}.',
42 | lambda c: f'a photo of a big {c}.',
43 | lambda c: f'a photo of the {c}.',
44 | lambda c: f'a blurry photo of the {c}.',
45 | lambda c: f'a black and white photo of the {c}.',
46 | lambda c: f'a low contrast photo of the {c}.',
47 | lambda c: f'a high contrast photo of the {c}.',
48 | lambda c: f'a bad photo of the {c}.',
49 | lambda c: f'a good photo of the {c}.',
50 | lambda c: f'a photo of the small {c}.',
51 | lambda c: f'a photo of the big {c}.',
52 | ]
53 |
54 | dtd_template = [
55 | lambda c: f'a photo of a {c} texture.',
56 | lambda c: f'a photo of a {c} pattern.',
57 | lambda c: f'a photo of a {c} thing.',
58 | lambda c: f'a photo of a {c} object.',
59 | lambda c: f'a photo of the {c} texture.',
60 | lambda c: f'a photo of the {c} pattern.',
61 | lambda c: f'a photo of the {c} thing.',
62 | lambda c: f'a photo of the {c} object.',
63 | ]
64 |
65 | eurosat_template = [
66 | lambda c: f'a centered satellite photo of {c}.',
67 | lambda c: f'a centered satellite photo of a {c}.',
68 | lambda c: f'a centered satellite photo of the {c}.',
69 | ]
70 |
71 | food101_template = [
72 | lambda c: f'a photo of {c}, a type of food.',
73 | ]
74 |
75 | gtsrb_template = [
76 | lambda c: f'a zoomed in photo of a "{c}" traffic sign.',
77 | lambda c: f'a centered photo of a "{c}" traffic sign.',
78 | lambda c: f'a close up photo of a "{c}" traffic sign.',
79 | ]
80 |
81 | mnist_template = [
82 | lambda c: f'a photo of the number: "{c}".',
83 | ]
84 |
85 | imagenet_template = [
86 | lambda c: f'a bad photo of a {c}.',
87 | lambda c: f'a photo of many {c}.',
88 | lambda c: f'a sculpture of a {c}.',
89 | lambda c: f'a photo of the hard to see {c}.',
90 | lambda c: f'a low resolution photo of the {c}.',
91 | lambda c: f'a rendering of a {c}.',
92 | lambda c: f'graffiti of a {c}.',
93 | lambda c: f'a bad photo of the {c}.',
94 | lambda c: f'a cropped photo of the {c}.',
95 | lambda c: f'a tattoo of a {c}.',
96 | lambda c: f'the embroidered {c}.',
97 | lambda c: f'a photo of a hard to see {c}.',
98 | lambda c: f'a bright photo of a {c}.',
99 | lambda c: f'a photo of a clean {c}.',
100 | lambda c: f'a photo of a dirty {c}.',
101 | lambda c: f'a dark photo of the {c}.',
102 | lambda c: f'a drawing of a {c}.',
103 | lambda c: f'a photo of my {c}.',
104 | lambda c: f'the plastic {c}.',
105 | lambda c: f'a photo of the cool {c}.',
106 | lambda c: f'a close-up photo of a {c}.',
107 | lambda c: f'a black and white photo of the {c}.',
108 | lambda c: f'a painting of the {c}.',
109 | lambda c: f'a painting of a {c}.',
110 | lambda c: f'a pixelated photo of the {c}.',
111 | lambda c: f'a sculpture of the {c}.',
112 | lambda c: f'a bright photo of the {c}.',
113 | lambda c: f'a cropped photo of a {c}.',
114 | lambda c: f'a plastic {c}.',
115 | lambda c: f'a photo of the dirty {c}.',
116 | lambda c: f'a jpeg corrupted photo of a {c}.',
117 | lambda c: f'a blurry photo of the {c}.',
118 | lambda c: f'a photo of the {c}.',
119 | lambda c: f'a good photo of the {c}.',
120 | lambda c: f'a rendering of the {c}.',
121 | lambda c: f'a {c} in a video game.',
122 | lambda c: f'a photo of one {c}.',
123 | lambda c: f'a doodle of a {c}.',
124 | lambda c: f'a close-up photo of the {c}.',
125 | lambda c: f'a photo of a {c}.',
126 | lambda c: f'the origami {c}.',
127 | lambda c: f'the {c} in a video game.',
128 | lambda c: f'a sketch of a {c}.',
129 | lambda c: f'a doodle of the {c}.',
130 | lambda c: f'a origami {c}.',
131 | lambda c: f'a low resolution photo of a {c}.',
132 | lambda c: f'the toy {c}.',
133 | lambda c: f'a rendition of the {c}.',
134 | lambda c: f'a photo of the clean {c}.',
135 | lambda c: f'a photo of a large {c}.',
136 | lambda c: f'a rendition of a {c}.',
137 | lambda c: f'a photo of a nice {c}.',
138 | lambda c: f'a photo of a weird {c}.',
139 | lambda c: f'a blurry photo of a {c}.',
140 | lambda c: f'a cartoon {c}.',
141 | lambda c: f'art of a {c}.',
142 | lambda c: f'a sketch of the {c}.',
143 | lambda c: f'a embroidered {c}.',
144 | lambda c: f'a pixelated photo of a {c}.',
145 | lambda c: f'itap of the {c}.',
146 | lambda c: f'a jpeg corrupted photo of the {c}.',
147 | lambda c: f'a good photo of a {c}.',
148 | lambda c: f'a plushie {c}.',
149 | lambda c: f'a photo of the nice {c}.',
150 | lambda c: f'a photo of the small {c}.',
151 | lambda c: f'a photo of the weird {c}.',
152 | lambda c: f'the cartoon {c}.',
153 | lambda c: f'art of the {c}.',
154 | lambda c: f'a drawing of the {c}.',
155 | lambda c: f'a photo of the large {c}.',
156 | lambda c: f'a black and white photo of a {c}.',
157 | lambda c: f'the plushie {c}.',
158 | lambda c: f'a dark photo of a {c}.',
159 | lambda c: f'itap of a {c}.',
160 | lambda c: f'graffiti of the {c}.',
161 | lambda c: f'a toy {c}.',
162 | lambda c: f'itap of my {c}.',
163 | lambda c: f'a photo of a cool {c}.',
164 | lambda c: f'a photo of a small {c}.',
165 | lambda c: f'a tattoo of the {c}.',
166 | ]
167 |
168 | resisc45_template = [
169 | lambda c: f'satellite imagery of {c}.',
170 | lambda c: f'aerial imagery of {c}.',
171 | lambda c: f'satellite photo of {c}.',
172 | lambda c: f'aerial photo of {c}.',
173 | lambda c: f'satellite view of {c}.',
174 | lambda c: f'aerial view of {c}.',
175 | lambda c: f'satellite imagery of a {c}.',
176 | lambda c: f'aerial imagery of a {c}.',
177 | lambda c: f'satellite photo of a {c}.',
178 | lambda c: f'aerial photo of a {c}.',
179 | lambda c: f'satellite view of a {c}.',
180 | lambda c: f'aerial view of a {c}.',
181 | lambda c: f'satellite imagery of the {c}.',
182 | lambda c: f'aerial imagery of the {c}.',
183 | lambda c: f'satellite photo of the {c}.',
184 | lambda c: f'aerial photo of the {c}.',
185 | lambda c: f'satellite view of the {c}.',
186 | lambda c: f'aerial view of the {c}.',
187 | ]
188 |
189 | stl10_template = [
190 | lambda c: f'a photo of a {c}.',
191 | lambda c: f'a photo of the {c}.',
192 | ]
193 |
194 | sun397_template = [
195 | lambda c: f'a photo of a {c}.',
196 | lambda c: f'a photo of the {c}.',
197 | ]
198 |
199 | svhn_template = [
200 | lambda c: f'a photo of the number: "{c}".',
201 | ]
202 |
203 |
204 | dataset_to_template = {
205 | 'Cars': cars_template,
206 | 'CIFAR10': cifar10_template,
207 | 'CIFAR100': cifar100_template,
208 | 'DTD': dtd_template,
209 | 'EuroSAT': eurosat_template,
210 | 'Food101': food101_template,
211 | 'GTSRB': gtsrb_template,
212 | 'MNIST': mnist_template,
213 | 'ImageNet': imagenet_template,
214 | 'RESISC45': resisc45_template,
215 | 'STL10': stl10_template,
216 | 'SUN397': sun397_template,
217 | 'SVHN': svhn_template,
218 | }
219 |
220 |
221 | def get_templates(dataset_name):
222 | if dataset_name.endswith('Val'):
223 | return get_templates(dataset_name.replace('Val', ''))
224 | assert dataset_name in dataset_to_template, f'Unsupported dataset: {dataset_name}'
225 | return dataset_to_template[dataset_name]
--------------------------------------------------------------------------------
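
Each template list expands a class name into several natural-language prompts; heads.py later embeds and averages them per class. A quick sketch of the expansion using get_templates from this file:

    templates = get_templates('MNIST')      # resolves to mnist_template
    prompts = [t('7') for t in templates]
    print(prompts)                          # ['a photo of the number: "7".']

    # a 'Val' suffix resolves to the base dataset's templates
    assert get_templates('SVHNVal') is svhn_template
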
/src/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import tqdm
4 |
5 | import torch
6 | import numpy as np
7 |
8 | import utils
9 | from datasets.common import get_dataloader, maybe_dictionarize
10 | from heads import get_classification_head
11 | from modeling import ImageClassifier
12 |
13 | from datasets.registry import get_dataset
14 |
15 | def eval_single_dataset(image_encoder, dataset_name, args):
16 | classification_head = get_classification_head(args, dataset_name)
17 | model = ImageClassifier(image_encoder, classification_head)
18 |
19 | model.eval()
20 |
21 | dataset = get_dataset(
22 | dataset_name,
23 | model.val_preprocess,
24 | location=args.data_location,
25 | batch_size=args.batch_size
26 | )
27 | dataloader = get_dataloader(
28 | dataset, is_train=False, args=args, image_encoder=None)
29 | device = args.device
30 |
31 | with torch.no_grad():
32 | top1, correct, n = 0., 0., 0.
33 | for i, data in enumerate(tqdm.tqdm(dataloader)):
34 | data = maybe_dictionarize(data)
35 | x = data['images'].to(device)
36 | y = data['labels'].to(device)
37 |
38 | logits = utils.get_logits(x, model)
39 |
40 | pred = logits.argmax(dim=1, keepdim=True).to(device)
41 |
42 | correct += pred.eq(y.view_as(pred)).sum().item()
43 |
44 | n += y.size(0)
45 |
46 | top1 = correct / n
47 |
48 | metrics = {'top1': top1}
49 | print(f'Done evaluating on {dataset_name}. Accuracy: {100*top1:.2f}%')
50 |
51 | return metrics
52 |
53 | def eval_single_dataset_head(image_encoder, head, dataset_name, args):
54 | model = ImageClassifier(image_encoder, head)
55 |
56 | model.eval()
57 |
58 | dataset = get_dataset(dataset_name, model.val_preprocess, location=args.data_location, batch_size=args.batch_size)
59 | dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None)
60 | device = args.device
61 |
62 | with torch.no_grad():
63 | top1, correct, n = 0., 0., 0.
64 | for i, data in enumerate(tqdm.tqdm(dataloader)):
65 | data = maybe_dictionarize(data)
66 | x = data['images'].to(device)
67 | y = data['labels'].to(device)
68 |
69 | logits = utils.get_logits(x, model)
70 |
71 | pred = logits.argmax(dim=1, keepdim=True).to(device)
72 |
73 | correct += pred.eq(y.view_as(pred)).sum().item()
74 |
75 | n += y.size(0)
76 |
77 | top1 = correct / n
78 |
79 | metrics = {'top1': top1}
80 | print(f'Done evaluating on {dataset_name}. Accuracy: {100 * top1:.2f}%')
81 |
82 | return metrics
83 |
84 | def eval_single_dataset_preprocess_head(image_encoder, head, dataset_name, args):
85 | model = ImageClassifier(image_encoder, head)
86 |
87 | model.eval()
88 |
89 | dataset = get_dataset(dataset_name, model.val_preprocess, location=args.data_location, batch_size=args.batch_size)
90 | dataloader = get_dataloader(dataset, is_train=False, args=args, image_encoder=None)
91 | device = args.device
92 |
93 | with torch.no_grad():
94 | top1, correct, n = 0., 0., 0.
95 | for i, data in enumerate(tqdm.tqdm(dataloader)):
96 | data = maybe_dictionarize(data)
97 | x = data['images'].to(device)
98 | y = data['labels'].to(device)
99 |
100 | logits = utils.get_logits(x, model)
101 |
102 | pred = logits.argmax(dim=1, keepdim=True).to(device)
103 |
104 | correct += pred.eq(y.view_as(pred)).sum().item()
105 |
106 | n += y.size(0)
107 |
108 | top1 = correct / n
109 |
110 | metrics = {'top1': top1}
111 | print(f'Done evaluating on {dataset_name}. Accuracy: {100 * top1:.2f}%')
112 |
113 | return metrics
114 |
115 | def evaluate(image_encoder, args):
116 | if args.eval_datasets is None:
117 | return
118 | info = vars(args)
119 | for i, dataset_name in enumerate(args.eval_datasets):
120 | print('Evaluating on', dataset_name)
121 |
122 | results = eval_single_dataset(image_encoder, dataset_name, args)
123 |
124 | if 'top1' in results:
125 | print(f"{dataset_name} Top-1 accuracy: {results['top1']:.4f}")
126 | for key, val in results.items():
127 | if 'worst' in key or 'f1' in key.lower() or 'pm0' in key:
128 | print(f"{dataset_name} {key}: {val:.4f}")
129 | info[dataset_name + ':' + key] = val
130 |
131 | if args.results_db is not None:
132 | dirname = os.path.dirname(args.results_db)
133 | if dirname:
134 | os.makedirs(dirname, exist_ok=True)
135 | with open(args.results_db, 'a+') as f:
136 | f.write(json.dumps(info) + '\n')
137 | print(f'Results saved to {args.results_db}.')
138 | else:
139 | print('Results not saved (to do so, use --results_db to specify a path).')
140 |
141 | return info
--------------------------------------------------------------------------------
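
All three eval_single_dataset* variants share the same top-1 accuracy loop: argmax over the logits, compare against the labels, accumulate counts. The core computation in isolation, on dummy tensors:

    import torch

    logits = torch.tensor([[2.0, 0.5], [0.1, 1.0], [3.0, 0.2]])  # 3 samples, 2 classes
    labels = torch.tensor([0, 1, 1])

    pred = logits.argmax(dim=1, keepdim=True)             # shape (3, 1)
    correct = pred.eq(labels.view_as(pred)).sum().item()  # 2 of 3 correct
    top1 = correct / labels.size(0)
    print(f'Accuracy: {100 * top1:.2f}%')                 # Accuracy: 66.67%
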
/src/heads.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from tqdm import tqdm
4 |
5 | import open_clip
6 |
7 | from datasets.templates import get_templates
8 | from datasets.registry import get_dataset
9 |
10 | from modeling import ClassificationHead, ImageEncoder
11 |
12 |
13 | def build_classification_head(model, dataset_name, template, data_location, device):
14 |     template = get_templates(dataset_name)  # note: shadows the `template` argument, which is ignored
15 |
16 | logit_scale = model.logit_scale
17 | dataset = get_dataset(
18 | dataset_name,
19 | None,
20 | location=data_location
21 | )
22 | model.eval()
23 | model.to(device)
24 |
25 | print('Building classification head.')
26 | with torch.no_grad():
27 | zeroshot_weights = []
28 | for classname in tqdm(dataset.classnames):
29 | texts = []
30 | for t in template:
31 | texts.append(t(classname))
32 | texts = open_clip.tokenize(texts).to(device) # tokenize
33 | embeddings = model.encode_text(texts) # embed with text encoder
34 | embeddings /= embeddings.norm(dim=-1, keepdim=True)
35 |
36 | embeddings = embeddings.mean(dim=0, keepdim=True)
37 | embeddings /= embeddings.norm()
38 |
39 | zeroshot_weights.append(embeddings)
40 |
41 | zeroshot_weights = torch.stack(zeroshot_weights, dim=0).to(device)
42 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 2)
43 |
44 | zeroshot_weights *= logit_scale.exp()
45 |
46 | zeroshot_weights = zeroshot_weights.squeeze().float()
47 | zeroshot_weights = torch.transpose(zeroshot_weights, 0, 1)
48 |
49 | classification_head = ClassificationHead(normalize=True, weights=zeroshot_weights)
50 |
51 | return classification_head
52 |
53 |
54 | def get_classification_head(args, dataset):
55 | filename = os.path.join(args.save, f'head_{dataset}.pt')
56 | if os.path.exists(filename):
57 | print(f'Classification head for {args.model} on {dataset} exists at {filename}')
58 | return ClassificationHead.load(filename)
59 | print(f'Did not find classification head for {args.model} on {dataset} at {filename}, building one from scratch.')
60 | model = ImageEncoder(args, keep_lang=True).model
61 | template = get_templates(dataset)
62 | classification_head = build_classification_head(model, dataset, template, args.data_location, args.device)
63 | os.makedirs(args.save, exist_ok=True)
64 | classification_head.save(filename)
65 | return classification_head
66 |
67 |
--------------------------------------------------------------------------------
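
build_classification_head turns class names into the rows of a linear head: embed every prompt for a class, L2-normalize, average, renormalize the mean, and scale by the CLIP logit scale. The same arithmetic on random stand-in embeddings (a sketch; the real values come from model.encode_text):

    import torch

    embed_dim, n_templates, logit_scale = 512, 8, torch.tensor(100.0)

    rows = []
    for _ in range(3):                           # three hypothetical classes
        e = torch.randn(n_templates, embed_dim)  # stand-in for encode_text output
        e = e / e.norm(dim=-1, keepdim=True)     # normalize each prompt embedding
        mean = e.mean(dim=0)
        rows.append(logit_scale * mean / mean.norm())  # renormalized class centroid

    weights = torch.stack(rows)                  # (num_classes, embed_dim)
    print(weights.shape)                         # torch.Size([3, 512])
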
/src/main_layer_wise_adamerging.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "3"
3 |
4 | import time
5 | import sys
6 | import tqdm
7 | sys.path.append('/home/taskarithmetic/')
8 |
9 | import torch
10 | from task_vectors import TaskVector
11 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head
12 | from args import parse_arguments
13 | def create_log_dir(path, filename='log.txt'):
14 | import logging
15 | if not os.path.exists(path):
16 | os.makedirs(path)
17 | logger = logging.getLogger(path)
18 | logger.setLevel(logging.DEBUG)
19 | fh = logging.FileHandler(path+'/'+filename)
20 | fh.setLevel(logging.DEBUG)
21 | ch = logging.StreamHandler()
22 | ch.setLevel(logging.DEBUG)
23 | logger.addHandler(fh)
24 | logger.addHandler(ch)
25 | return logger
26 |
27 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
28 | model = 'ViT-B-32'
29 | args = parse_arguments()
30 | args.data_location = '/home/taskarithmetic/data'
31 | args.model = model
32 | args.save = '/home/taskarithmetic/checkpoints/' + model
33 | args.logs_path = '/home/taskarithmetic/logs/' + model
34 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt'
35 |
36 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
37 | log = create_log_dir(args.logs_path, 'log_{}_Layer_wise_AdaMerging.txt'.format(str_time_))
38 | args.log = log
39 |
40 | task_vectors = [TaskVector(pretrained_checkpoint, '/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets]
41 |
42 | def del_attr(obj, names):
43 | if len(names) == 1:
44 | delattr(obj, names[0])
45 | else:
46 | del_attr(getattr(obj, names[0]), names[1:])
47 |
48 | def set_attr(obj, names, val):
49 | if len(names) == 1:
50 | setattr(obj, names[0], val)
51 | else:
52 | set_attr(getattr(obj, names[0]), names[1:], val)
53 |
54 | def make_functional(mod):
55 | orig_params = tuple(mod.parameters())
56 | names = []
57 | for name, p in list(mod.named_parameters()):
58 | del_attr(mod, name.split("."))
59 | names.append(name)
60 | return orig_params, names
61 |
62 | def load_weights(mod, names, params):
63 | for name, p in zip(names, params):
64 | set_attr(mod, name.split("."), p)
65 |
66 | class ModelWrapper(torch.nn.Module):
67 | def __init__(self, model, initial_weights=None):
68 | super(ModelWrapper, self).__init__()
69 | self.model = model
70 |
71 | if hasattr(self.model, 'transformer'):
72 | delattr(self.model, 'transformer')
73 |
74 | def forward(self, images):
75 | features = self.model(images)
76 | return features
77 |
78 | from heads import get_classification_head
79 | class AdaMerging(torch.nn.Module):
80 | def __init__(self, paramslist, model, names, exam_datasets):
81 | super(AdaMerging, self).__init__()
82 | self.paramslist = paramslist
83 | self.model = model
84 | self.names = names
85 | self.pretrain_lambdas = torch.ones(len(paramslist[0]), 1)
86 | prior = 0.3
87 | rlambdas = torch.ones(len(paramslist[0]), len(paramslist)-1) * prior # (1 * tasks)
88 | self.lambdas_raw = torch.nn.Parameter(rlambdas)
89 |
90 | self.classifier = []
91 | for dataset_name in exam_datasets:
92 | classification_head = get_classification_head(args, dataset_name)
93 | layer_name = 'classifier_{}'.format(dataset_name)
94 | self.add_module(layer_name, classification_head.to(args.device))
95 | self.classifier.append(layer_name)
96 |
97 | def lambdas(self):
98 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0)
99 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1)
100 | return lambdass
101 |
102 | def collect_trainable_params(self):
103 | return [self.lambdas_raw]
104 |
105 | def get_classification_head(self, dataset_name):
106 | layer_name = 'classifier_{}'.format(dataset_name)
107 | classification_head = getattr(self, layer_name)
108 | return classification_head
109 |
110 | def get_image_encoder(self):
111 | alph = self.lambdas()
112 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
113 | params = tuple(p.cuda(0) for p in params)
114 | load_weights(self.model, self.names, params)
115 | return self.model
116 |
117 | def forward(self, inp, dataset_name):
118 | alph = self.lambdas()
119 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
120 |
121 | params = tuple(p.cuda(0) for p in params)
122 | load_weights(self.model, self.names, params)
123 | feature = self.model(inp)
124 |
125 | layer_name = 'classifier_{}'.format(dataset_name)
126 | classification_head = getattr(self, layer_name)
127 | out = classification_head(feature)
128 | return out
129 |
130 | def softmax_entropy(x):
131 | return -(x.softmax(1) * x.log_softmax(1)).sum(1)
132 |
133 | pretrained_model = torch.load(pretrained_checkpoint)
134 | pretrained_model_dic = pretrained_model.state_dict()
135 |
136 | model = ModelWrapper(pretrained_model, exam_datasets)  # note: second argument binds to the unused initial_weights parameter
137 | model = model.to(args.device)
138 | _, names = make_functional(model)
139 |
140 | paramslist = []
141 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain
142 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.vector.items()) for i, tv in enumerate(task_vectors)] # task vectors
143 | torch.cuda.empty_cache()
144 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets)
145 |
146 | print('init lambda:')
147 | print(adamerging_mtl_model.lambdas())
148 | print('collect_trainable_params:')
149 | print(list(adamerging_mtl_model.collect_trainable_params()))
150 |
151 | epochs = 500
152 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.)
153 |
154 | from datasets.registry import get_dataset
155 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle
156 |
157 | Total_ACC = 0.
158 | for dataset_name in exam_datasets:
159 | image_encoder = adamerging_mtl_model.get_image_encoder()
160 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name)
161 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args)
162 | Total_ACC += metrics['top1']
163 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
164 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
165 |
166 | for epoch in range(epochs):
167 | losses = 0.
168 | for dataset_name in exam_datasets:
169 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16)
170 | dataloader = get_dataloader_shuffle(dataset)
171 |
172 | for i, data in enumerate(tqdm.tqdm(dataloader)):
173 | data = maybe_dictionarize(data)
174 | x = data['images'].to(args.device)
175 | y = data['labels'].to(args.device)
176 |
177 | outputs = adamerging_mtl_model(x, dataset_name)
178 | loss = softmax_entropy(outputs).mean(0)
179 | losses += loss
180 |
181 | if i > 0:
182 | break
183 |
184 | optimizer.zero_grad()
185 | losses.backward()
186 | optimizer.step()
187 |
188 | print(list(adamerging_mtl_model.lambdas().data))
189 |
190 | if ((epoch+1) % 500) == 0:
191 | log.info(str(list(adamerging_mtl_model.lambdas().data)))
192 |
193 | Total_ACC = 0.
194 | for dataset_name in exam_datasets:
195 | image_encoder = adamerging_mtl_model.get_image_encoder()
196 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name)
197 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args)
198 | Total_ACC += metrics['top1']
199 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
200 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
201 |
--------------------------------------------------------------------------------
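
The heart of this script is the merge inside get_image_encoder/forward: for every parameter tensor j, the merged weight is theta_j = theta_j_pre + sum_k lambda_{j,k} * tau_{j,k}, with one learnable coefficient per (layer, task) pair, trained by minimizing softmax entropy on unlabeled test batches. A toy sketch of both pieces:

    import torch

    n_layers, n_tasks = 4, 2
    pretrained = [torch.randn(3) for _ in range(n_layers)]
    task_vecs = [[torch.randn(3) for _ in range(n_layers)] for _ in range(n_tasks)]

    lambdas = torch.full((n_layers, n_tasks), 0.3, requires_grad=True)  # prior = 0.3
    merged = [
        pretrained[j] + sum(lambdas[j, k] * task_vecs[k][j] for k in range(n_tasks))
        for j in range(n_layers)
    ]

    def softmax_entropy(x):
        return -(x.softmax(1) * x.log_softmax(1)).sum(1)

    logits = torch.randn(16, 10)             # stand-in for classifier outputs
    loss = softmax_entropy(logits).mean(0)   # unsupervised objective, no labels needed
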
/src/main_layer_wise_adamergingpp.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "3"
3 |
4 | import time
5 | import sys
6 | import tqdm
7 | sys.path.append('/home/taskarithmetic/')
8 |
9 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head
10 | from args import parse_arguments
11 |
12 | def create_log_dir(path, filename='log.txt'):
13 | import logging
14 | if not os.path.exists(path):
15 | os.makedirs(path)
16 | logger = logging.getLogger(path)
17 | logger.setLevel(logging.DEBUG)
18 | fh = logging.FileHandler(path+'/'+filename)
19 | fh.setLevel(logging.DEBUG)
20 | ch = logging.StreamHandler()
21 | ch.setLevel(logging.DEBUG)
22 | logger.addHandler(fh)
23 | logger.addHandler(ch)
24 | return logger
25 |
26 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
27 | model = 'ViT-B-32'
28 | args = parse_arguments()
29 | args.data_location = '/home/taskarithmetic/data'
30 | args.model = model
31 | args.save = '/home/taskarithmetic/checkpoints/' + model
32 | args.logs_path = '/home/taskarithmetic/logs/' + model
33 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt'
34 |
35 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
36 | log = create_log_dir(args.logs_path, 'log_{}_Layer_wise_AdaMergingPP.txt'.format(str_time_))
37 | args.log = log
38 |
39 | from ties_merging_utils import *
40 |
41 | ft_checks = [torch.load('/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt').state_dict() for dataset_name in exam_datasets]
42 | ptm_check = torch.load(pretrained_checkpoint).state_dict()
43 |
44 | check_parameterNamesMatch(ft_checks + [ptm_check])
45 |
46 | remove_keys = []
47 | print("Flattening out checkpoints")
48 | flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])
49 | flat_ptm = state_dict_to_vector(ptm_check, remove_keys)
50 |
51 | tv_flat_checks = flat_ft - flat_ptm
52 |
53 | assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check)
54 | assert all([check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_checks[i]) for i in range(len(ft_checks))])
55 |
56 |
57 | K = 20
58 | merge_func = "dis-sum"
59 |
60 | selected_entries, merged_tv = ties_merging_split(tv_flat_checks, reset_thresh=K, merge_func=merge_func,)
61 |
62 | ties_task_vectors = []
63 | for vector_ in selected_entries:
64 | t_state_dict = vector_to_state_dict(vector_, ptm_check, remove_keys=remove_keys)
65 | ref_model = torch.load(pretrained_checkpoint)
66 | ref_model.load_state_dict(t_state_dict, strict=False)
67 | ties_task_vectors.append(ref_model.state_dict())
68 |
69 | def del_attr(obj, names):
70 | if len(names) == 1:
71 | delattr(obj, names[0])
72 | else:
73 | del_attr(getattr(obj, names[0]), names[1:])
74 |
75 | def set_attr(obj, names, val):
76 | if len(names) == 1:
77 | setattr(obj, names[0], val)
78 | else:
79 | set_attr(getattr(obj, names[0]), names[1:], val)
80 |
81 | def make_functional(mod):
82 | orig_params = tuple(mod.parameters())
83 | names = []
84 | for name, p in list(mod.named_parameters()):
85 | del_attr(mod, name.split("."))
86 | names.append(name)
87 | return orig_params, names
88 |
89 | def load_weights(mod, names, params):
90 | for name, p in zip(names, params):
91 | set_attr(mod, name.split("."), p)
92 |
93 |
94 | class ModelWrapper(torch.nn.Module):
95 | def __init__(self, model, initial_weights=None):
96 | super(ModelWrapper, self).__init__()
97 | self.model = model
98 |
99 | if hasattr(self.model, 'transformer'):
100 | delattr(self.model, 'transformer')
101 |
102 | def forward(self, images):
103 | features = self.model(images)
104 | return features
105 |
106 | from heads import get_classification_head
107 | class AdaMerging(torch.nn.Module):
108 | def __init__(self, paramslist, model, names, exam_datasets):
109 | super(AdaMerging, self).__init__()
110 | self.paramslist = paramslist
111 | self.model = model
112 | self.names = names
113 | self.pretrain_lambdas = torch.ones(len(paramslist[0]), 1)
114 | prior = 0.3
115 | rlambdas = torch.ones(len(paramslist[0]), len(paramslist)-1) * prior # (1 * tasks)
116 | self.lambdas_raw = torch.nn.Parameter(rlambdas)
117 |
118 | self.classifier = []
119 | for dataset_name in exam_datasets:
120 | classification_head = get_classification_head(args, dataset_name)
121 | layer_name = 'classifier_{}'.format(dataset_name)
122 | self.add_module(layer_name, classification_head.to(args.device))
123 | self.classifier.append(layer_name)
124 |
125 | def lambdas(self):
126 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0)
127 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1)
128 | return lambdass
129 |
130 | def collect_trainable_params(self):
131 | return [self.lambdas_raw]
132 |
133 | def get_classification_head(self, dataset_name):
134 | layer_name = 'classifier_{}'.format(dataset_name)
135 | classification_head = getattr(self, layer_name)
136 | return classification_head
137 |
138 | def get_image_encoder(self):
139 | alph = self.lambdas()
140 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
141 | params = tuple(p.cuda(0) for p in params)
142 | load_weights(self.model, self.names, params)
143 | return self.model
144 |
145 | def forward(self, inp, dataset_name):
146 | alph = self.lambdas()
147 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[j].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
148 |
149 | params = tuple(p.cuda(0) for p in params)
150 | load_weights(self.model, self.names, params)
151 | feature = self.model(inp)
152 |
153 | layer_name = 'classifier_{}'.format(dataset_name)
154 | classification_head = getattr(self, layer_name)
155 | out = classification_head(feature)
156 |
157 | return out
158 |
159 | def softmax_entropy(x):
160 | return -(x.softmax(1) * x.log_softmax(1)).sum(1)
161 |
162 | pretrained_model = torch.load(pretrained_checkpoint)
163 | pretrained_model_dic = pretrained_model.state_dict()
164 |
165 | model = ModelWrapper(pretrained_model, exam_datasets)
166 | model = model.to(args.device)
167 | _, names = make_functional(model)
168 |
169 | paramslist = []
170 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain
171 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.items()) for i, tv in enumerate(ties_task_vectors)] # task vectors
172 | torch.cuda.empty_cache()
173 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets)
174 |
175 | print('init lambda:')
176 | print(adamerging_mtl_model.lambdas())
177 | print('collect_trainable_params:')
178 | print(list(adamerging_mtl_model.collect_trainable_params()))
179 |
180 | epochs = 500
181 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.)
182 |
183 | from datasets.registry import get_dataset
184 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle
185 |
186 | Total_ACC = 0.
187 | for dataset_name in exam_datasets:
188 | image_encoder = adamerging_mtl_model.get_image_encoder()
189 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name)
190 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args)
191 | Total_ACC += metrics['top1']
192 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
193 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
194 |
195 | for epoch in range(epochs):
196 | losses = 0.
197 | for dataset_name in exam_datasets:
198 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16)
199 | dataloader = get_dataloader_shuffle(dataset)
200 |
201 | for i, data in enumerate(tqdm.tqdm(dataloader)):
202 | data = maybe_dictionarize(data)
203 | x = data['images'].to(args.device)
204 | y = data['labels'].to(args.device)
205 |
206 | outputs = adamerging_mtl_model(x, dataset_name)
207 | loss = softmax_entropy(outputs).mean(0)
208 | losses += loss
209 |
210 | if i > 0:
211 | break
212 |
213 | optimizer.zero_grad()
214 | losses.backward()
215 | optimizer.step()
216 |
217 | print(list(adamerging_mtl_model.lambdas().data))
218 |
219 | if ((epoch+1) % 500) == 0:
220 | log.info(str(list(adamerging_mtl_model.lambdas().data)))
221 |
222 | Total_ACC = 0.
223 | for dataset_name in exam_datasets:
224 | image_encoder = adamerging_mtl_model.get_image_encoder()
225 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name)
226 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args)
227 | Total_ACC += metrics['top1']
228 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
229 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
230 |
--------------------------------------------------------------------------------
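
AdaMerging++ first applies TIES preprocessing, which operates on flat parameter vectors; state_dict_to_vector and vector_to_state_dict (from ties_merging_utils, not included in this dump) convert between a state dict and one long vector. A self-contained round-trip sketch, assuming they simply concatenate flattened tensors in key order:

    import torch

    def state_dict_to_vector(sd):
        # assumption: flatten each tensor, concatenate in key order
        return torch.cat([v.reshape(-1) for v in sd.values()])

    def vector_to_state_dict(vec, reference_sd):
        out, offset = {}, 0
        for k, v in reference_sd.items():
            out[k] = vec[offset:offset + v.numel()].view_as(v)
            offset += v.numel()
        return out

    sd = {'w': torch.randn(2, 3), 'b': torch.randn(2)}
    flat = state_dict_to_vector(sd)              # shape (8,)
    restored = vector_to_state_dict(flat, sd)
    assert all(torch.equal(sd[k], restored[k]) for k in sd)
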
/src/main_task_arithmetic.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
6 |
7 | import time
8 | import sys
9 | sys.path.append('/home/taskarithmetic/')
10 |
11 | from task_vectors import TaskVector
12 | from eval import eval_single_dataset
13 | from args import parse_arguments
14 |
15 | def create_log_dir(path, filename='log.txt'):
16 | import logging
17 | if not os.path.exists(path):
18 | os.makedirs(path)
19 | logger = logging.getLogger(path)
20 | logger.setLevel(logging.DEBUG)
21 | fh = logging.FileHandler(path+'/'+filename)
22 | fh.setLevel(logging.DEBUG)
23 | ch = logging.StreamHandler()
24 | ch.setLevel(logging.DEBUG)
25 | logger.addHandler(fh)
26 | logger.addHandler(ch)
27 | return logger
28 |
29 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
30 | model = 'ViT-B-32'
31 | args = parse_arguments()
32 | args.data_location = '/home/taskarithmetic/data'
33 | args.model = model
34 | args.save = '/home/taskarithmetic/checkpoints/' + model
35 | args.logs_path = '/home/taskarithmetic/logs/' + model
36 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt'
37 |
38 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
39 | log = create_log_dir(args.logs_path, 'log_{}_task_arithmetic.txt'.format(str_time_))
40 |
41 | task_vectors = [
42 | TaskVector(pretrained_checkpoint, '/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets
43 | ]
44 |
45 | task_vector_sum = sum(task_vectors)
46 |
47 | scaling_coef_ = 0.3
48 |
49 | image_encoder = task_vector_sum.apply_to(pretrained_checkpoint, scaling_coef=scaling_coef_)
50 | log.info('*'*20 + 'scaling_coef:' + str(scaling_coef_) + '*'*20)
51 |
52 | accs = []
53 | for dataset in exam_datasets:
54 | metrics = eval_single_dataset(image_encoder, dataset, args)
55 | log.info(str(dataset) + ':' + str(metrics.get('top1')*100)+'%')
56 | accs.append(metrics.get('top1')*100)
57 | log.info('Avg ACC:' + str(np.mean(accs)) + '%')
58 |
--------------------------------------------------------------------------------
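
Task arithmetic itself is a one-liner per parameter: theta_merged = theta_pre + lambda * sum_k tau_k, with a single global scaling coefficient (0.3 above) shared by all tasks and layers. On toy tensors:

    import torch

    pretrained = torch.zeros(4)
    tau = [torch.tensor([1., 0., 0., 0.]), torch.tensor([0., 2., 0., 0.])]

    scaling_coef = 0.3
    merged = pretrained + scaling_coef * sum(tau)
    print(merged)   # tensor([0.3000, 0.6000, 0.0000, 0.0000])
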
/src/main_task_wise_adamerging.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 |
4 | import time
5 | import sys
6 | import tqdm
7 | sys.path.append('/root/taskarithmetic')
8 |
9 | import torch
10 | from task_vectors import TaskVector
11 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head
12 | from args import parse_arguments
13 |
14 | def create_log_dir(path, filename='log.txt'):
15 | import logging
16 | if not os.path.exists(path):
17 | os.makedirs(path)
18 | logger = logging.getLogger(path)
19 | logger.setLevel(logging.DEBUG)
20 | fh = logging.FileHandler(path+'/'+filename)
21 | fh.setLevel(logging.DEBUG)
22 | ch = logging.StreamHandler()
23 | ch.setLevel(logging.DEBUG)
24 | logger.addHandler(fh)
25 | logger.addHandler(ch)
26 | return logger
27 |
28 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
29 | model = 'ViT-B-32'
30 | source_root_path = '/root'
31 | args = parse_arguments()
32 | args.data_location = source_root_path+'/dataset'
33 | args.model = model
34 | args.save = source_root_path+'/checkpoint/' + model
35 | args.logs_path = '/root/taskarithmetic/src/logs/' + model
36 | pretrained_checkpoint = source_root_path+'/checkpoint/'+model+'/zeroshot.pt'
37 |
38 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
39 | log = create_log_dir(args.logs_path, 'log_{}_Task_wise_AdaMerging.txt'.format(str_time_))
40 | args.log = log
41 |
42 | task_vectors = [TaskVector(pretrained_checkpoint, source_root_path+'/checkpoint/'+model+'/'+dataset_name+'/finetuned.pt') for dataset_name in exam_datasets]
43 |
44 | def del_attr(obj, names):
45 | if len(names) == 1:
46 | delattr(obj, names[0])
47 | else:
48 | del_attr(getattr(obj, names[0]), names[1:])
49 |
50 | def set_attr(obj, names, val):
51 | if len(names) == 1:
52 | setattr(obj, names[0], val)
53 | else:
54 | set_attr(getattr(obj, names[0]), names[1:], val)
55 |
56 | def make_functional(mod):
57 | orig_params = tuple(mod.parameters())
58 | names = []
59 | for name, p in list(mod.named_parameters()):
60 | del_attr(mod, name.split("."))
61 | names.append(name)
62 | return orig_params, names
63 |
64 | def load_weights(mod, names, params):
65 | for name, p in zip(names, params):
66 | set_attr(mod, name.split("."), p)
67 |
68 |
69 | class ModelWrapper(torch.nn.Module):
70 | def __init__(self, model, initial_weights=None):
71 | super(ModelWrapper, self).__init__()
72 | self.model = model
73 |
74 | if hasattr(self.model, 'transformer'):
75 | delattr(self.model, 'transformer')
76 |
77 | def forward(self, images):
78 | features = self.model(images)
79 | return features
80 |
81 | from heads import get_classification_head
82 | class AdaMerging(torch.nn.Module):
83 | def __init__(self, paramslist, model, names, exam_datasets):
84 | super(AdaMerging, self).__init__()
85 | self.paramslist = paramslist
86 | self.model = model
87 | self.names = names
88 | self.pretrain_lambdas = torch.ones(1, 1)
89 | prior = 0.3
90 | rlambdas = torch.ones(1, len(paramslist)-1) * prior # (1 * tasks)
91 | self.lambdas_raw = torch.nn.Parameter(rlambdas)
92 |
93 | self.classifier = []
94 | for dataset_name in exam_datasets:
95 | classification_head = get_classification_head(args, dataset_name)
96 | layer_name = 'classifier_{}'.format(dataset_name)
97 | self.add_module(layer_name, classification_head.to(args.device))
98 | self.classifier.append(layer_name)
99 |
100 | def lambdas(self):
101 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0)
102 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1)
103 | return lambdass
104 |
105 | def collect_trainable_params(self):
106 | return [self.lambdas_raw]
107 |
108 | def get_classification_head(self, dataset_name):
109 | layer_name = 'classifier_{}'.format(dataset_name)
110 | classification_head = getattr(self, layer_name)
111 | return classification_head
112 |
113 | def get_image_encoder(self):
114 | alph = self.lambdas()
115 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
116 | params = tuple(p.cuda(0) for p in params)
117 | load_weights(self.model, self.names, params)
118 | return self.model
119 |
120 | def forward(self, inp, dataset_name):
121 | alph = self.lambdas()
122 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
123 |
124 | params = tuple(p.cuda(0) for p in params)
125 | load_weights(self.model, self.names, params)
126 | feature = self.model(inp)
127 |
128 | layer_name = 'classifier_{}'.format(dataset_name)
129 | classification_head = getattr(self, layer_name)
130 | out = classification_head(feature)
131 |
132 | return out
133 |
134 | def softmax_entropy(x):
135 | return -(x.softmax(1) * x.log_softmax(1)).sum(1)
136 |
137 | pretrained_model = torch.load(pretrained_checkpoint)
138 | pretrained_model_dic = pretrained_model.state_dict()
139 |
140 | model = ModelWrapper(pretrained_model, exam_datasets)
141 | model = model.to(args.device)
142 | _, names = make_functional(model)
143 |
144 | paramslist = []
145 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain
146 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.vector.items()) for i, tv in enumerate(task_vectors)] # task vectors
147 | torch.cuda.empty_cache()
148 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets)
149 |
150 | print('init lambda:')
151 | print(adamerging_mtl_model.lambdas())
152 | print('collect_trainable_params:')
153 | print(list(adamerging_mtl_model.collect_trainable_params()))
154 |
155 | epochs = 500
156 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.)
157 |
158 | from datasets.registry import get_dataset
159 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle
160 |
161 | Total_ACC = 0.
162 | for dataset_name in exam_datasets:
163 | image_encoder = adamerging_mtl_model.get_image_encoder()
164 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name)
165 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args)
166 | Total_ACC += metrics['top1']
167 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
168 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
169 |
170 | for epoch in range(epochs):
171 | losses = 0.
172 | for dataset_name in exam_datasets:
173 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16)
174 | dataloader = get_dataloader_shuffle(dataset)
175 |
176 | for i, data in enumerate(tqdm.tqdm(dataloader)):
177 | data = maybe_dictionarize(data)
178 | x = data['images'].to(args.device)
179 | y = data['labels'].to(args.device)
180 |
181 | outputs = adamerging_mtl_model(x, dataset_name)
182 | loss = softmax_entropy(outputs).mean(0)
183 | losses += loss
184 |
185 | if i > 0:
186 | break
187 |
188 | optimizer.zero_grad()
189 | losses.backward()
190 | optimizer.step()
191 |
192 | if ((epoch+1) % 500) == 0:
193 | log.info(str(list(adamerging_mtl_model.lambdas().data)))
194 |
195 | Total_ACC = 0.
196 | for dataset_name in exam_datasets:
197 | image_encoder = adamerging_mtl_model.get_image_encoder()
198 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name)
199 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args)
200 | Total_ACC += metrics['top1']
201 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
202 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
203 |
--------------------------------------------------------------------------------
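
Task-wise AdaMerging differs from the layer-wise variant only in coefficient shape: one lambda per task, shared across all layers, rather than one per (layer, task) pair. The lambdas() method clamps the raw parameters into [0, 1] and prepends a fixed coefficient of 1 for the pretrained weights:

    import torch

    n_tasks = 8
    pretrain_lambdas = torch.ones(1, 1)
    lambdas_raw = torch.nn.Parameter(torch.full((1, n_tasks), 0.3))

    task_lambdas = torch.clamp(lambdas_raw, min=0.0, max=1.0)
    lambdas = torch.cat((pretrain_lambdas, task_lambdas), 1)
    print(lambdas.shape)   # torch.Size([1, 9]); values [1.0, 0.3, ..., 0.3]
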
/src/main_task_wise_adamergingpp.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
3 |
4 | import time
5 | import sys
6 | import tqdm
7 | sys.path.append('/home/taskarithmetic/')
8 |
9 | from eval import eval_single_dataset, eval_single_dataset_head, eval_single_dataset_preprocess_head
10 | from args import parse_arguments
11 |
12 | def create_log_dir(path, filename='log.txt'):
13 | import logging
14 | if not os.path.exists(path):
15 | os.makedirs(path)
16 | logger = logging.getLogger(path)
17 | logger.setLevel(logging.DEBUG)
18 | fh = logging.FileHandler(path+'/'+filename)
19 | fh.setLevel(logging.DEBUG)
20 | ch = logging.StreamHandler()
21 | ch.setLevel(logging.DEBUG)
22 | logger.addHandler(fh)
23 | logger.addHandler(ch)
24 | return logger
25 |
26 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
27 | model = 'ViT-B-32'
28 | args = parse_arguments()
29 | args.data_location = '/home/taskarithmetic/data'
30 | args.model = model
31 | args.save = '/home/taskarithmetic/checkpoints/' + model
32 | args.logs_path = '/home/taskarithmetic/logs/' + model
33 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt'
34 |
35 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
36 | log = create_log_dir(args.logs_path, 'log_{}_Task_wise_AdaMergingPP.txt'.format(str_time_))
37 | args.log = log
38 |
39 | from ties_merging_utils import *
40 |
41 | ft_checks = [torch.load('/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt').state_dict() for dataset_name in exam_datasets]
42 | ptm_check = torch.load(pretrained_checkpoint).state_dict()
43 |
44 | check_parameterNamesMatch(ft_checks + [ptm_check])
45 |
46 | remove_keys = []
47 |
48 | print("Flattening out checkpoints")
49 | flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])
50 | flat_ptm = state_dict_to_vector(ptm_check, remove_keys)
51 |
52 | tv_flat_checks = flat_ft - flat_ptm
53 |
54 | assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check)
55 | assert all([check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_checks[i]) for i in range(len(ft_checks))])
56 |
57 | K = 20
58 | merge_func = "dis-sum"
59 |
60 | selected_entries, merged_tv = ties_merging_split(tv_flat_checks, reset_thresh=K, merge_func=merge_func,)
61 |
62 | ties_task_vectors = []
63 | for vector_ in selected_entries:
64 | t_state_dict = vector_to_state_dict(vector_, ptm_check, remove_keys=remove_keys)
65 | ref_model = torch.load(pretrained_checkpoint)
66 | ref_model.load_state_dict(t_state_dict, strict=False)
67 | ties_task_vectors.append(ref_model.state_dict())
68 |
69 | def del_attr(obj, names):
70 | if len(names) == 1:
71 | delattr(obj, names[0])
72 | else:
73 | del_attr(getattr(obj, names[0]), names[1:])
74 |
75 | def set_attr(obj, names, val):
76 | if len(names) == 1:
77 | setattr(obj, names[0], val)
78 | else:
79 | set_attr(getattr(obj, names[0]), names[1:], val)
80 |
81 | def make_functional(mod):
82 | orig_params = tuple(mod.parameters())
83 | # Remove all the parameters in the model
84 | names = []
85 | for name, p in list(mod.named_parameters()):
86 | del_attr(mod, name.split("."))
87 | names.append(name)
88 | return orig_params, names
89 |
90 | def load_weights(mod, names, params):
91 | for name, p in zip(names, params):
92 | set_attr(mod, name.split("."), p)
93 |
94 |
95 | class ModelWrapper(torch.nn.Module):
96 | def __init__(self, model, initial_weights=None):
97 | super(ModelWrapper, self).__init__()
98 | self.model = model
99 |
100 | if hasattr(self.model, 'transformer'):
101 | delattr(self.model, 'transformer')
102 |
103 | def forward(self, images):
104 | features = self.model(images)
105 | return features
106 |
107 | from heads import get_classification_head
108 | class AdaMerging(torch.nn.Module):
109 | def __init__(self, paramslist, model, names, exam_datasets):
110 | super(AdaMerging, self).__init__()
111 | self.paramslist = paramslist
112 | self.model = model
113 | self.names = names
114 | self.pretrain_lambdas = torch.ones(1, 1)
115 | prior = 0.3
116 | rlambdas = torch.ones(1, len(paramslist)-1) * prior # (1 * tasks)
117 | self.lambdas_raw = torch.nn.Parameter(rlambdas)
118 |
119 | self.classifier = []
120 | for dataset_name in exam_datasets:
121 | classification_head = get_classification_head(args, dataset_name)
122 | layer_name = 'classifier_{}'.format(dataset_name)
123 | self.add_module(layer_name, classification_head.to(args.device))
124 | self.classifier.append(layer_name)
125 |
126 | def lambdas(self):
127 | task_lambdas = torch.clamp(self.lambdas_raw, min=0.0, max=1.0)
128 | lambdass = torch.cat((self.pretrain_lambdas, task_lambdas), 1)
129 | return lambdass
130 |
131 | def collect_trainable_params(self):
132 | return [self.lambdas_raw]
133 |
134 | def get_classification_head(self, dataset_name):
135 | layer_name = 'classifier_{}'.format(dataset_name)
136 | classification_head = getattr(self, layer_name)
137 | return classification_head
138 |
139 | def get_image_encoder(self):
140 | alph = self.lambdas()
141 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
142 | params = tuple(p.cuda(0) for p in params)
143 | load_weights(self.model, self.names, params)
144 | return self.model
145 |
146 | def forward(self, inp, dataset_name):
147 | alph = self.lambdas()
148 | params = tuple(sum(tuple(pi * lambdasi for pi, lambdasi in zip(p, alph[0].cpu()))) for j, p in enumerate(zip(*self.paramslist)))
149 |
150 | params = tuple(p.cuda(0) for p in params)
151 | load_weights(self.model, self.names, params)
152 | feature = self.model(inp)
153 |
154 | layer_name = 'classifier_{}'.format(dataset_name)
155 | classification_head = getattr(self, layer_name)
156 | out = classification_head(feature)
157 |
158 | return out
159 |
160 | def softmax_entropy(x):
161 | return -(x.softmax(1) * x.log_softmax(1)).sum(1)
162 |
163 | pretrained_model = torch.load(pretrained_checkpoint)
164 | pretrained_model_dic = pretrained_model.state_dict()
165 |
166 | model = ModelWrapper(pretrained_model, exam_datasets)
167 | model = model.to(args.device)
168 | _, names = make_functional(model)
169 |
170 | paramslist = []
171 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in pretrained_model_dic.items())] # pretrain
172 | paramslist += [tuple(v.detach().requires_grad_().cpu() for _, v in tv.items()) for i, tv in enumerate(ties_task_vectors)] # task vectors
173 | torch.cuda.empty_cache()
174 | adamerging_mtl_model = AdaMerging(paramslist, model, names, exam_datasets)
175 |
176 | print('init lambda:')
177 | print(adamerging_mtl_model.lambdas())
178 | print('collect_trainable_params:')
179 | print(list(adamerging_mtl_model.collect_trainable_params()))
180 |
181 | epochs = 500
182 | optimizer = torch.optim.Adam(adamerging_mtl_model.collect_trainable_params(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.)
183 |
184 | from datasets.registry import get_dataset
185 | from datasets.common import get_dataloader, maybe_dictionarize, get_dataloader_shuffle
186 |
187 | Total_ACC = 0.
188 | for dataset_name in exam_datasets:
189 | image_encoder = adamerging_mtl_model.get_image_encoder()
190 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name)
191 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args)
192 | Total_ACC += metrics['top1']
193 | log.info('Eval: init: ' + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
194 | log.info('Eval: init: ' + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
195 |
196 | for epoch in range(epochs):
197 | losses = 0.
198 | for dataset_name in exam_datasets:
199 | dataset = get_dataset(dataset_name, pretrained_model.val_preprocess, location=args.data_location, batch_size=16)
200 | dataloader = get_dataloader_shuffle(dataset)
201 |
202 | for i, data in enumerate(tqdm.tqdm(dataloader)):
203 | data = maybe_dictionarize(data)
204 | x = data['images'].to(args.device)
205 | y = data['labels'].to(args.device)
206 |
207 | outputs = adamerging_mtl_model(x, dataset_name)
208 | loss = softmax_entropy(outputs).mean(0)
209 | losses += loss
210 |
211 | if i > 0:
212 | break
213 |
214 | optimizer.zero_grad()
215 | losses.backward()
216 | optimizer.step()
217 |
218 | print(list(adamerging_mtl_model.lambdas().data))
219 |
220 | if ((epoch+1) % 500) == 0:
221 | log.info(str(list(adamerging_mtl_model.lambdas().data)))
222 |
223 | Total_ACC = 0.
224 | for dataset_name in exam_datasets:
225 | image_encoder = adamerging_mtl_model.get_image_encoder()
226 | classification_head = adamerging_mtl_model.get_classification_head(dataset_name)
227 | metrics = eval_single_dataset_preprocess_head(image_encoder, classification_head, dataset_name, args)
228 | Total_ACC += metrics['top1']
229 | log.info('Eval: Epoch: ' + str(epoch) + ' dataset: ' + str(dataset_name) + ' ACC: ' + str(metrics['top1']))
230 | log.info('Eval: Epoch: ' + str(epoch) + ' Avg ACC:' + str(Total_ACC / len(exam_datasets)) + '\n')
231 |
--------------------------------------------------------------------------------
/src/main_ties_merging.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_VISIBLE_DEVICES"] = "3"
3 |
4 | import time
5 | import sys
6 | sys.path.append('/home/taskarithmetic/')
7 |
8 | from eval import eval_single_dataset
9 | from args import parse_arguments
10 |
11 | def create_log_dir(path, filename='log.txt'):
12 | import logging
13 | if not os.path.exists(path):
14 | os.makedirs(path)
15 | logger = logging.getLogger(path)
16 | logger.setLevel(logging.DEBUG)
17 | fh = logging.FileHandler(path+'/'+filename)
18 | fh.setLevel(logging.DEBUG)
19 | ch = logging.StreamHandler()
20 | ch.setLevel(logging.DEBUG)
21 | logger.addHandler(fh)
22 | logger.addHandler(ch)
23 | return logger
24 |
25 | exam_datasets = ['SUN397', 'Cars', 'RESISC45', 'EuroSAT', 'SVHN', 'GTSRB', 'MNIST', 'DTD'] # SUN397 | Cars | RESISC45 | EuroSAT | SVHN | GTSRB | MNIST | DTD
26 | model = 'ViT-B-32'
27 | args = parse_arguments()
28 | args.data_location = '/home/taskarithmetic/data'
29 | args.model = model
30 | args.save = '/home/taskarithmetic/checkpoints/' + model
31 | args.logs_path = '/home/taskarithmetic/logs/' + model
32 | pretrained_checkpoint = '/home/taskarithmetic/checkpoints/'+model+'/zeroshot.pt'
33 |
34 | str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
35 | log = create_log_dir(args.logs_path, 'log_{}_ties_merging.txt'.format(str_time_))
36 | import torch
37 | from ties_merging_utils import *
38 | ft_checks = [torch.load('/home/taskarithmetic/checkpoints/'+model+'/'+dataset_name+'/finetuned.pt').state_dict() for dataset_name in exam_datasets]
39 | ptm_check = torch.load(pretrained_checkpoint).state_dict()
40 | check_parameterNamesMatch(ft_checks + [ptm_check])
41 |
42 | remove_keys = []
43 | print(f"Flattening out Checkpoints")
44 | flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])
45 | flat_ptm = state_dict_to_vector(ptm_check, remove_keys)
46 |
47 | tv_flat_checks = flat_ft - flat_ptm
48 | assert check_state_dicts_equal(vector_to_state_dict(flat_ptm, ptm_check, remove_keys), ptm_check)
49 | assert all([check_state_dicts_equal(vector_to_state_dict(flat_ft[i], ptm_check, remove_keys), ft_checks[i]) for i in range(len(ft_checks))])
50 |
51 |
52 | K = 20                  # trim: keep only the top 20% of task-vector entries by magnitude
53 | merge_func = "dis-sum"  # disjoint aggregation: sum the entries that agree with the elected sign
54 | scaling_coef_ = 0.3     # scale applied to the merged task vector before adding it to the pretrained weights
55 |
56 | merged_tv = ties_merging(tv_flat_checks, reset_thresh=K, merge_func=merge_func,)
57 | merged_check = flat_ptm + scaling_coef_ * merged_tv
58 | merged_state_dict = vector_to_state_dict(merged_check, ptm_check, remove_keys=remove_keys)
59 |
60 | image_encoder = torch.load(pretrained_checkpoint)
61 | image_encoder.load_state_dict(merged_state_dict, strict=False)
62 |
63 | Total_ACC = 0.
64 | for dataset in exam_datasets:
65 | metrics = eval_single_dataset(image_encoder, dataset, args)
66 | Total_ACC += metrics['top1']
67 | log.info(str(dataset) + ':' + str(metrics))
68 |
69 | log.info('Final: ' + 'Avg ACC:' + str(Total_ACC / len(exam_datasets)))
70 |
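71 | # Recap of the TIES-Merging pipeline implemented above: (1) flatten each
72 | # fine-tuned checkpoint into a task vector (fine-tuned minus pretrained),
73 | # (2) trim all but the top-K% entries by magnitude, (3) elect a per-parameter
74 | # sign by majority, (4) disjointly sum the entries that agree with that sign,
75 | # and (5) add the scaled merged vector back to the pretrained weights.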
--------------------------------------------------------------------------------
/src/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import open_clip
4 |
5 | import utils
6 |
7 |
8 | class ImageEncoder(torch.nn.Module):
9 | def __init__(self, args, keep_lang=False):
10 | super().__init__()
11 |
12 | print(f'Loading {args.model} pre-trained weights.')
13 | if '__pretrained__' in args.model:
14 | name, pretrained = args.model.split('__pretrained__')
15 | else:
16 | name = args.model
17 | pretrained = 'openai'
18 | self.model, self.train_preprocess, self.val_preprocess = open_clip.create_model_and_transforms(
19 | name, pretrained=pretrained, cache_dir=args.openclip_cachedir)
20 |
21 | self.cache_dir = args.cache_dir
22 |
23 | if not keep_lang and hasattr(self.model, 'transformer'):
24 | delattr(self.model, 'transformer')
25 |
26 | def forward(self, images):
27 | assert self.model is not None
28 | return self.model.encode_image(images)
29 |
30 | def __call__(self, inputs):
31 | return self.forward(inputs)
32 |
33 | def save(self, filename):
34 | print(f'Saving image encoder to {filename}')
35 | utils.torch_save(self, filename)
36 |
37 | @classmethod
38 | def load(cls, model_name, filename):
39 | print(f'Loading image encoder from {filename}')
40 | state_dict = torch.load(filename)
41 |         return cls.load_from_state_dict(model_name, state_dict)
42 |
43 |     @classmethod
44 |     def load_from_state_dict(cls, model_name, state_dict):
45 |         # Rebuild the open_clip backbone for the given model name, then load the weights.
46 |         model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained='openai')
47 |         model.load_state_dict(state_dict)   # the given weights override the 'openai' initialization
48 |         return model
49 |
50 |
51 |
52 | class ClassificationHead(torch.nn.Linear):
53 |     def __init__(self, normalize, weights, biases=None):
54 |         # weights is required: its shape defines the layer dimensions.
55 |         output_size, input_size = weights.shape
56 |         super().__init__(input_size, output_size)
57 |         self.normalize = normalize
58 |         self.weight = torch.nn.Parameter(weights.clone())
59 |         if biases is not None:
60 |             self.bias = torch.nn.Parameter(biases.clone())
61 |         else:
62 |             self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
63 |
64 | def forward(self, inputs):
65 | if self.normalize:
66 | inputs = inputs / inputs.norm(dim=-1, keepdim=True)
67 | return super().forward(inputs)
68 |
69 | def __call__(self, inputs):
70 | return self.forward(inputs)
71 |
72 | def save(self, filename):
73 | print(f'Saving classification head to {filename}')
74 | utils.torch_save(self, filename)
75 |
76 | @classmethod
77 | def load(cls, filename):
78 | print(f'Loading classification head from {filename}')
79 | return utils.torch_load(filename)
80 |
81 |
82 | class ImageClassifier(torch.nn.Module):
83 | def __init__(self, image_encoder, classification_head):
84 | super().__init__()
85 | self.image_encoder = image_encoder
86 | self.classification_head = classification_head
87 | if self.image_encoder is not None:
88 | if hasattr(self.image_encoder, 'train_preprocess'):
89 | self.train_preprocess = self.image_encoder.train_preprocess
90 | self.val_preprocess = self.image_encoder.val_preprocess
91 | elif hasattr(self.image_encoder.model, 'train_preprocess'):
92 | self.train_preprocess = self.image_encoder.model.train_preprocess
93 | self.val_preprocess = self.image_encoder.model.val_preprocess
94 |
95 | def freeze_head(self):
96 | self.classification_head.weight.requires_grad_(False)
97 | self.classification_head.bias.requires_grad_(False)
98 |
99 | def forward(self, inputs):
100 | features = self.image_encoder(inputs)
101 | outputs = self.classification_head(features)
102 | return outputs
103 |
104 | def __call__(self, inputs):
105 | return self.forward(inputs)
106 |
107 | def save(self, filename):
108 | print(f'Saving image classifier to {filename}')
109 | utils.torch_save(self, filename)
110 |
111 | @classmethod
112 | def load(cls, filename):
113 | print(f'Loading image classifier from {filename}')
114 | return utils.torch_load(filename)
115 |
116 | class ImageClassifier_debug(torch.nn.Module):
117 | def __init__(self, image_encoder, image_encoder2, classification_head):
118 | super().__init__()
119 | self.image_encoder = image_encoder
120 | self.image_encoder2 = image_encoder2
121 | self.classification_head = classification_head
122 | if self.image_encoder is not None:
123 | self.train_preprocess = self.image_encoder.train_preprocess
124 | self.val_preprocess = self.image_encoder.val_preprocess
125 |
126 | def freeze_head(self):
127 | self.classification_head.weight.requires_grad_(False)
128 | self.classification_head.bias.requires_grad_(False)
129 |
130 | def forward(self, inputs):
131 | features = self.image_encoder(inputs)
132 | features2 = self.image_encoder2(inputs)
133 | outputs = self.classification_head(features + features2)
134 | return outputs
135 |
136 | def __call__(self, inputs):
137 | return self.forward(inputs)
138 |
139 | def save(self, filename):
140 | print(f'Saving image classifier to {filename}')
141 | utils.torch_save(self, filename)
142 |
143 | @classmethod
144 | def load(cls, filename):
145 | print(f'Loading image classifier from {filename}')
146 | return utils.torch_load(filename)
147 |
148 | class MultiHeadImageClassifier(torch.nn.Module):
149 | def __init__(self, image_encoder, classification_heads):
150 | super().__init__()
151 | self.image_encoder = image_encoder
152 | self.classification_heads = torch.nn.ModuleList(classification_heads)
153 | if self.image_encoder is not None:
154 | self.train_preprocess = self.image_encoder.train_preprocess
155 | self.val_preprocess = self.image_encoder.val_preprocess
156 |
157 | def freeze_head(self):
158 | for idx in range(len(self.classification_heads)):
159 | self.classification_heads[idx].weight.requires_grad_(False)
160 | self.classification_heads[idx].bias.requires_grad_(False)
161 |
162 | def forward(self, inputs, head_idx):
163 | features = self.image_encoder(inputs)
164 | outputs = self.classification_heads[head_idx](features)
165 | return outputs
166 |
167 | def __call__(self, inputs, head_idx):
168 | return self.forward(inputs, head_idx)
169 |
170 | def save(self, filename):
171 | print(f'Saving image classifier to {filename}')
172 | utils.torch_save(self, filename)
173 |
174 | @classmethod
175 | def load(cls, filename):
176 | print(f'Loading image classifier from {filename}')
177 | return utils.torch_load(filename)
178 |
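179 | # Minimal usage sketch (hypothetical paths; assumes parsed args with a valid
180 | # open_clip model name such as 'ViT-B-32'):
181 | #
182 | #   encoder = ImageEncoder(args)
183 | #   head = ClassificationHead.load('checkpoints/ViT-B-32/head_MNIST.pt')
184 | #   classifier = ImageClassifier(encoder, head)
185 | #   logits = classifier(images)   # images: a preprocessed batch tensor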
--------------------------------------------------------------------------------
/src/task_vectors.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class TaskVector():
5 | def __init__(self, pretrained_checkpoint=None, finetuned_checkpoint=None, vector=None):
6 | """Initializes the task vector from a pretrained and a finetuned checkpoints.
7 |
8 | This can either be done by passing two state dicts (one corresponding to the
9 | pretrained model, and another to the finetuned model), or by directly passying in
10 | the task vector state dict.
11 | """
12 | if vector is not None:
13 | self.vector = vector
14 | else:
15 | assert pretrained_checkpoint is not None and finetuned_checkpoint is not None
16 | with torch.no_grad():
17 | print('TaskVector:' + finetuned_checkpoint)
18 | pretrained_state_dict = torch.load(pretrained_checkpoint).state_dict()
19 | finetuned_state_dict = torch.load(finetuned_checkpoint).state_dict()
20 | self.vector = {}
21 | for key in pretrained_state_dict:
22 | if pretrained_state_dict[key].dtype in [torch.int64, torch.uint8]:
23 | continue
24 | self.vector[key] = finetuned_state_dict[key] - pretrained_state_dict[key]
25 |
26 | def __add__(self, other):
27 | """Add two task vectors together."""
28 | with torch.no_grad():
29 | new_vector = {}
30 | for key in self.vector:
31 | if key not in other.vector:
32 |                     print(f'Warning: key {key} is not present in both task vectors.')
33 | continue
34 | new_vector[key] = self.vector[key] + other.vector[key]
35 | return TaskVector(vector=new_vector)
36 |
37 | def __radd__(self, other):
38 | if other is None or isinstance(other, int):
39 | return self
40 | return self.__add__(other)
41 |
42 | def __neg__(self):
43 | """Negate a task vector."""
44 | with torch.no_grad():
45 | new_vector = {}
46 | for key in self.vector:
47 | new_vector[key] = - self.vector[key]
48 | return TaskVector(vector=new_vector)
49 |
50 | def weightmerging(self, taskvectors, coefficients):
51 | with torch.no_grad():
52 | new_vector = {}
53 | for key in taskvectors[0].vector:
54 |                 new_vector[key] = sum(coefficients[k] * taskvectors[k].vector[key] for k in range(len(taskvectors)))
55 | return TaskVector(vector=new_vector)
56 |
57 | def apply_to(self, pretrained_checkpoint, scaling_coef=1.0):
58 | """Apply a task vector to a pretrained model."""
59 | with torch.no_grad():
60 | pretrained_model = torch.load(pretrained_checkpoint)
61 | new_state_dict = {}
62 | pretrained_state_dict = pretrained_model.state_dict()
63 | for key in pretrained_state_dict:
64 | if key not in self.vector:
65 | print(f'Warning: key {key} is present in the pretrained state dict but not in the task vector')
66 | continue
67 | new_state_dict[key] = pretrained_state_dict[key] + scaling_coef * self.vector[key]
68 | pretrained_model.load_state_dict(new_state_dict, strict=False)
69 | return pretrained_model
70 |
71 |
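72 | # Minimal usage sketch of task arithmetic (hypothetical checkpoint paths):
73 | #
74 | #   tv_a = TaskVector('checkpoints/zeroshot.pt', 'checkpoints/MNIST/finetuned.pt')
75 | #   tv_b = TaskVector('checkpoints/zeroshot.pt', 'checkpoints/SVHN/finetuned.pt')
76 | #   merged_model = (tv_a + tv_b).apply_to('checkpoints/zeroshot.pt', scaling_coef=0.3)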
--------------------------------------------------------------------------------
/src/ties_merging_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os, copy
3 | import torch
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import re
7 | from collections import OrderedDict
8 | import torch.nn.functional as F
9 | # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
10 |
11 | ## Model conversion utils
12 | def state_dict_to_vector(state_dict, remove_keys=[]):
13 | shared_state_dict = copy.deepcopy(state_dict)
14 | for key in remove_keys:
15 | if key in shared_state_dict:
16 | del shared_state_dict[key]
17 | sorted_shared_state_dict = OrderedDict(sorted(shared_state_dict.items()))
18 | return torch.nn.utils.parameters_to_vector(
19 | [value.reshape(-1) for key, value in sorted_shared_state_dict.items()]
20 | )
21 |
22 |
23 | def vector_to_state_dict(vector, state_dict, remove_keys=[]):
24 | # create a reference dict to define the order of the vector
25 | reference_dict = copy.deepcopy(state_dict)
26 | for key in remove_keys:
27 | if key in reference_dict:
28 | del reference_dict[key]
29 | sorted_reference_dict = OrderedDict(sorted(reference_dict.items()))
30 |
31 |     # create a shared state dict using the reference dict
32 | torch.nn.utils.vector_to_parameters(vector, sorted_reference_dict.values())
33 |
34 | # add back the encoder and decoder embedding weights.
35 | if "transformer.shared.weight" in sorted_reference_dict:
36 | for key in remove_keys:
37 | sorted_reference_dict[key] = sorted_reference_dict[
38 | "transformer.shared.weight"
39 | ]
40 | return sorted_reference_dict
41 |
42 |
43 | def add_ptm_to_tv(tv_dict, ptm_dict):
44 | assert set(tv_dict.keys()) == set(
45 | ptm_dict.keys()
46 | ), "Differing parameter names in models."
47 | final_dict = copy.deepcopy(tv_dict)
48 | for k, v in ptm_dict.items():
49 | final_dict[k] = tv_dict[k] + v
50 | return final_dict
51 |
52 |
53 | def check_parameterNamesMatch(checkpoints):
54 | parameter_names = set(checkpoints[0].keys())
55 |
56 |     if len(checkpoints) >= 2:
57 |         # compare every checkpoint's parameter names against the first checkpoint
58 | for checkpoint in checkpoints[1:]:
59 | current_parameterNames = set(checkpoint.keys())
60 | if current_parameterNames != parameter_names:
61 | raise ValueError(
62 | "Differing parameter names in models. "
63 | f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
64 | )
65 |
66 | def check_state_dicts_equal(state_dict1, state_dict2):
67 | if set(state_dict1.keys()) != set(state_dict2.keys()):
68 | return False
69 |
70 | for key in state_dict1.keys():
71 | if not torch.equal(state_dict1[key], state_dict2[key]):
72 | return False
73 |
74 | return True
75 |
76 |
77 |
78 | ## TIES MERGING UTILS
79 |
80 | def topk_values_mask(M, K=0.7, return_mask=False):
81 |     if K > 1:
82 |         K /= 100   # allow K to be given as a percentage (e.g. 20 -> 0.2)
83 |
84 | original_shape = M.shape
85 | if M.dim() == 1:
86 | M = M.unsqueeze(0)
87 |
88 | n, d = M.shape
89 | k = int(d * K)
90 | k = d - k # Keep top k elements instead of bottom k elements
91 |
92 | # Find the k-th smallest element by magnitude for each row
93 | kth_values, _ = M.abs().kthvalue(k, dim=1, keepdim=True)
94 | # Create a mask tensor with True for the top k elements in each row
95 | mask = M.abs() >= kth_values
96 | final_mask = mask.squeeze() if original_shape == M.squeeze().shape else mask
97 |
98 | if return_mask:
99 | return M * final_mask, final_mask.float().mean(dim=1), final_mask
100 | return M * final_mask, final_mask.float().mean(dim=1)
101 |
102 |
103 | def resolve_zero_signs(sign_to_mult, method="majority"):
104 | majority_sign = torch.sign(sign_to_mult.sum())
105 |
106 | if method == "majority":
107 | sign_to_mult[sign_to_mult == 0] = majority_sign
108 | elif method == "minority":
109 | sign_to_mult[sign_to_mult == 0] = -1 * majority_sign
110 | return sign_to_mult
111 |
112 |
113 | def resolve_sign(Tensor):
114 | sign_to_mult = torch.sign(Tensor.sum(dim=0))
115 | sign_to_mult = resolve_zero_signs(sign_to_mult, "majority")
116 | return sign_to_mult
117 |
118 |
119 | def disjoint_merge(Tensor, merge_func, sign_to_mult):
120 | merge_func = merge_func.split("-")[-1]
121 |
122 | # If sign is provided then we select the corresponding entries and aggregate.
123 | if sign_to_mult is not None:
124 | rows_to_keep = torch.where(
125 | sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
126 | )
127 | selected_entries = Tensor * rows_to_keep
128 | # Else we select all non-zero entries and aggregate.
129 | else:
130 | rows_to_keep = Tensor != 0
131 | selected_entries = Tensor * rows_to_keep
132 |
133 | if merge_func == "mean":
134 | non_zero_counts = (selected_entries != 0).sum(dim=0).float()
135 | disjoint_aggs = torch.sum(selected_entries, dim=0) / torch.clamp(non_zero_counts, min=1)
136 | elif merge_func == "sum":
137 | disjoint_aggs = torch.sum(selected_entries, dim=0)
138 | elif merge_func == "max":
139 | disjoint_aggs = selected_entries.abs().max(dim=0)[0]
140 | disjoint_aggs *= sign_to_mult
141 | else:
142 | raise ValueError(f"Merge method {merge_func} is not defined.")
143 |
144 | return disjoint_aggs
145 |
146 |
147 | def ties_merging(
148 | flat_task_checks,
149 | reset_thresh=None,
150 | merge_func="",
151 | ):
152 | all_checks = flat_task_checks.clone()
153 | updated_checks, *_ = topk_values_mask(
154 | all_checks, K=reset_thresh, return_mask=False
155 | )
156 | print(f"RESOLVING SIGN")
157 | final_signs = resolve_sign(updated_checks)
158 | assert final_signs is not None
159 |
160 | print(f"Disjoint AGGREGATION: {merge_func}")
161 | merged_tv = disjoint_merge(updated_checks, merge_func, final_signs)
162 |
163 | return merged_tv
164 |
165 | def disjoint_merge_split(Tensor, merge_func, sign_to_mult):
166 | merge_func = merge_func.split("-")[-1]
167 |
168 | # If sign is provided then we select the corresponding entries and aggregate.
169 | if sign_to_mult is not None:
170 | rows_to_keep = torch.where(
171 | sign_to_mult.unsqueeze(0) > 0, Tensor > 0, Tensor < 0
172 | )
173 | selected_entries = Tensor * rows_to_keep
174 | # Else we select all non-zero entries and aggregate.
175 | else:
176 | rows_to_keep = Tensor != 0
177 | selected_entries = Tensor * rows_to_keep
178 |
179 | if merge_func == "sum":
180 | disjoint_aggs = torch.sum(selected_entries, dim=0)
181 | else:
182 | raise ValueError(f"Merge method {merge_func} is not defined.")
183 |
184 | return selected_entries, disjoint_aggs
185 |
186 |
187 | def ties_merging_split(
188 | flat_task_checks,
189 | reset_thresh=None,
190 | merge_func="",
191 | ):
192 | all_checks = flat_task_checks.clone()
193 | updated_checks, *_ = topk_values_mask(
194 | all_checks, K=reset_thresh, return_mask=False
195 | )
196 | print(f"RESOLVING SIGN")
197 | final_signs = resolve_sign(updated_checks)
198 | assert final_signs is not None
199 |
200 | print(f"Disjoint AGGREGATION: {merge_func}")
201 | selected_entries, merged_tv = disjoint_merge_split(updated_checks, merge_func, final_signs)
202 |
203 | return selected_entries, merged_tv
204 |
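205 | # Minimal usage sketch (random toy data, not from the repository):
206 | #
207 | #   flat_tvs = torch.randn(8, 1000)   # 8 task vectors over 1000 parameters
208 | #   merged = ties_merging(flat_tvs, reset_thresh=20, merge_func="dis-sum")
209 | #   assert merged.shape == (1000,)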
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import pickle
5 | from tqdm import tqdm
6 | import math
7 |
8 | import numpy as np
9 |
10 |
11 | def assign_learning_rate(param_group, new_lr):
12 | param_group["lr"] = new_lr
13 |
14 |
15 | def _warmup_lr(base_lr, warmup_length, step):
16 | return base_lr * (step + 1) / warmup_length
17 |
18 |
19 | def cosine_lr(optimizer, base_lrs, warmup_length, steps):
20 | if not isinstance(base_lrs, list):
21 | base_lrs = [base_lrs for _ in optimizer.param_groups]
22 | assert len(base_lrs) == len(optimizer.param_groups)
23 | def _lr_adjuster(step):
24 | for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
25 | if step < warmup_length:
26 | lr = _warmup_lr(base_lr, warmup_length, step)
27 | else:
28 | e = step - warmup_length
29 | es = steps - warmup_length
30 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
31 | assign_learning_rate(param_group, lr)
32 | return _lr_adjuster
33 |
34 |
35 | def accuracy(output, target, topk=(1,)):
36 | pred = output.topk(max(topk), 1, True, True)[1].t()
37 | correct = pred.eq(target.view(1, -1).expand_as(pred))
38 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
39 |
40 |
41 | def torch_load_old(save_path, device=None):
42 | with open(save_path, 'rb') as f:
43 | classifier = pickle.load(f)
44 | if device is not None:
45 | classifier = classifier.to(device)
46 | return classifier
47 |
48 |
49 | def torch_save(model, save_path):
50 | if os.path.dirname(save_path) != '':
51 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
52 | torch.save(model.cpu(), save_path)
53 |
54 |
55 | def torch_load(save_path, device=None):
56 | model = torch.load(save_path)
57 | if device is not None:
58 | model = model.to(device)
59 | return model
60 |
61 |
62 |
63 | def get_logits(inputs, classifier):
64 | assert callable(classifier)
65 | if hasattr(classifier, 'to'):
66 | classifier = classifier.to(inputs.device)
67 | return classifier(inputs)
68 |
69 |
70 | def get_probs(inputs, classifier):
71 | if hasattr(classifier, 'predict_proba'):
72 | probs = classifier.predict_proba(inputs.detach().cpu().numpy())
73 | return torch.from_numpy(probs)
74 | logits = get_logits(inputs, classifier)
75 | return logits.softmax(dim=1)
76 |
77 |
78 | class LabelSmoothing(torch.nn.Module):
79 | def __init__(self, smoothing=0.0):
80 | super(LabelSmoothing, self).__init__()
81 | self.confidence = 1.0 - smoothing
82 | self.smoothing = smoothing
83 |
84 | def forward(self, x, target):
85 | logprobs = torch.nn.functional.log_softmax(x, dim=-1)
86 |
87 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
88 | nll_loss = nll_loss.squeeze(1)
89 | smooth_loss = -logprobs.mean(dim=-1)
90 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
91 | return loss.mean()
92 |
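93 | # Minimal usage sketch for cosine_lr (hypothetical model/optimizer):
94 | #
95 | #   optimizer = torch.optim.SGD(model.parameters(), lr=0.0)
96 | #   adjust_lr = cosine_lr(optimizer, base_lrs=1e-3, warmup_length=50, steps=1000)
97 | #   for step in range(1000):
98 | #       adjust_lr(step)   # linear warmup for 50 steps, then cosine decay
99 | #       ...               # forward/backward/optimizer.step()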
--------------------------------------------------------------------------------