├── .gitignore ├── LEEP.py ├── LICENSE.txt ├── LogME.py ├── NCE.py ├── README.md ├── b_tuning.py ├── imgs ├── image-20210222204141915.png ├── image-20210222204350389.png └── image-20210222204712553.png ├── ranking.py └── utils ├── tools.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
def LEEP(pseudo_source_label: np.ndarray, target_label: np.ndarray):
    """
    Compute the LEEP transferability score (Nguyen et al., ICML 2020).

    :param pseudo_source_label: shape [N, C_s], predicted probabilities over the
                                C_s source (pre-trained) classes for each sample
    :param target_label: shape [N], elements in [0, C_t)
    :return: leep score (float); higher means better expected transferability
    """
    N, C_s = pseudo_source_label.shape
    # Cast once so both the class-mask comparison and the final fancy-indexing
    # work even when labels arrive as a float array.
    target_label = target_label.reshape(-1).astype(int)
    C_t = int(np.max(target_label) + 1)  # the number of target classes
    normalized_prob = pseudo_source_label / float(N)  # sum(normalized_prob) = 1
    # Empirical joint distribution over (target label y, source class z).
    joint = np.zeros((C_t, C_s), dtype=float)
    for i in range(C_t):
        this_class = normalized_prob[target_label == i]
        joint[i] = np.sum(this_class, axis=0)
    # Column-normalize by the marginal p(z); transpose gives shape [C_s, C_t].
    p_target_given_source = (joint / joint.sum(axis=0, keepdims=True)).T  # P(y | z)

    # Expected target prediction for each sample: sum_z p(z|x) * P(y|z).
    empirical_prediction = pseudo_source_label @ p_target_given_source
    # Vectorized gather of P(y_i | x_i) — replaces the original per-sample
    # Python loop with a single fancy-indexing operation.
    empirical_prob = empirical_prediction[np.arange(N), target_label]
    leep_score = np.mean(np.log(empirical_prob))
    return leep_score
@njit
def each_evidence(y_, f, fh, v, s, vh, N, D):
    """
    compute the maximum evidence for each class

    Alternately re-estimates the hyper-parameters (alpha, beta) of the Bayesian
    linear model by fixed-point iteration, then evaluates the log-evidence.

    :param y_: [N] target vector (one-hot column for classification, or one
               regression dimension)
    :param f: [D, N] transposed feature matrix
    :param fh: [N, D] feature matrix (f.transpose())
    :param v, s, vh: eigendecomposition pieces of f @ fh (v: eigenvectors,
                     s: eigenvalues) — NOTE(review): assumed from the caller
                     `_fit_icml`; confirm shapes there
    :param N: number of samples
    :param D: feature dimension
    :return: (evidence / N, alpha, beta, m) where m is the posterior mean weight
    """
    epsilon = 1e-5  # regularizer guarding the divisions below against zero
    alpha = 1.0
    beta = 1.0
    lam = alpha / beta  # effective ridge coefficient alpha / beta
    # Project y onto the eigenbasis once; reused every iteration.
    tmp = (vh @ (f @ np.ascontiguousarray(y_)))
    for _ in range(11):
        # should converge after at most 10 steps
        # typically converge after two or three steps
        gamma = (s / (s + lam)).sum()  # effective number of well-determined parameters
        # A = v @ np.diag(alpha + beta * s) @ v.transpose() # no need to compute A
        # A_inv = v @ np.diag(1.0 / (alpha + beta * s)) @ v.transpose() # no need to compute A_inv
        m = v @ (tmp * beta / (alpha + beta * s))  # posterior mean in original basis
        alpha_de = (m * m).sum()
        alpha = gamma / (alpha_de + epsilon)
        beta_de = ((y_ - fh @ m) ** 2).sum()  # residual sum of squares
        beta = (N - gamma) / (beta_de + epsilon)
        new_lam = alpha / beta
        # Stop once the ridge coefficient changes by less than 1% (relative).
        if np.abs(new_lam - lam) / lam < 0.01:
            break
        lam = new_lam
    # Log marginal likelihood of the evidence framework, normalized by N.
    evidence = D / 2.0 * np.log(alpha) \
               + N / 2.0 * np.log(beta) \
               - 0.5 * np.sum(np.log(alpha + beta * s)) \
               - beta / 2.0 * (beta_de + epsilon) \
               - alpha / 2.0 * (alpha_de + epsilon) \
               - N / 2.0 * np.log(2 * np.pi)
    return evidence / N, alpha, beta, m


# use pseudo data to compile the function
# D = 20, N = 50
# (numba JIT warm-up at import time so the first real call is already compiled)
f_tmp = np.random.randn(20, 50).astype(np.float64)
each_evidence(np.random.randint(0, 2, 50).astype(np.float64), f_tmp, f_tmp.transpose(), np.eye(20, dtype=np.float64), np.ones(20, dtype=np.float64), np.eye(20, dtype=np.float64), 50, 20)
@njit
def truncated_svd(x):
    """
    Truncated SVD of x via the D x D Gram matrix x^T x — cheaper than a direct
    SVD of an [N, D] matrix when N >> D.

    :param x: [N, D] matrix
    :return: (u, s, vh) with numerically-zero singular values dropped, so that
             u @ np.diag(s.ravel()) @ vh reconstructs x up to that truncation
    """
    # Eigen-decompose the Gram matrix; its singular values are the squares
    # of x's singular values, hence the sqrt below.
    u, s, vh = np.linalg.svd(x.transpose() @ x)
    s = np.sqrt(s)
    # Left singular vectors scaled by sigma: x @ V = U * Sigma.
    u_times_sigma = x @ vh.transpose()
    k = np.sum((s > 1e-10) * 1)  # rank of f (count of non-negligible singular values)
    s = s.reshape(-1, 1)
    s = s[:k]
    vh = vh[:k]
    # Divide out sigma to recover orthonormal left singular vectors.
    u = u_times_sigma[:, :k] / s.reshape(1, -1)
    return u, s, vh


# numba JIT warm-up at import time with pseudo data
truncated_svd(np.random.randn(20, 10).astype(np.float64))
class LogME(object):
    def __init__(self, regression=False):
        """
        LogME transferability scorer.

        :param regression: whether regression (True) or classification (False)
        """
        self.regression = regression
        self.fitted = False  # guards predict(); set on first fit()
        self.reset()

    def reset(self):
        # Clear all per-class / per-dimension posterior parameters.
        self.num_dim = 0
        self.alphas = []  # alpha for each class / dimension
        self.betas = []  # beta for each class / dimension
        # self.ms.shape --> [C, D]
        self.ms = []  # m for each class / dimension

    def _fit_icml(self, f: np.ndarray, y: np.ndarray):
        """
        LogME calculation proposed in the ICML 2021 paper
        "LogME: Practical Assessment of Pre-trained Models for Transfer Learning"
        at http://proceedings.mlr.press/v139/you21b.html

        :param f: [N, D] feature matrix, :param y: targets; returns mean evidence.
        """
        fh = f
        f = f.transpose()
        D, N = f.shape
        # Eigendecomposition of the D x D matrix f @ fh, shared across classes.
        v, s, vh = np.linalg.svd(f @ fh, full_matrices=True)

        evidences = []
        self.num_dim = y.shape[1] if self.regression else int(y.max() + 1)
        for i in range(self.num_dim):
            # One-vs-rest target for classification; i-th column for regression.
            y_ = y[:, i] if self.regression else (y == i).astype(np.float64)
            evidence, alpha, beta, m = each_evidence(y_, f, fh, v, s, vh, N, D)
            evidences.append(evidence)
            self.alphas.append(alpha)
            self.betas.append(beta)
            self.ms.append(m)
        self.ms = np.stack(self.ms)
        return np.mean(evidences)

    def _fit_fixed_point(self, f: np.ndarray, y: np.ndarray):
        """
        LogME calculation proposed in the arxiv 2021 paper
        "Ranking and Tuning Pre-trained Models: A New Paradigm of Exploiting Model Hubs"
        at https://arxiv.org/abs/2110.10545

        Faster fixed-point variant: works entirely in the k = min(N, D)
        dimensional singular-value space, so each (alpha, beta) update is O(k).
        """
        N, D = f.shape  # k = min(N, D)
        if N > D:  # direct SVD may be expensive
            u, s, vh = truncated_svd(f)
        else:
            u, s, vh = np.linalg.svd(f, full_matrices=False)
        # u.shape = N x k
        # s.shape = k
        # vh.shape = k x D
        s = s.reshape(-1, 1)
        sigma = (s ** 2)  # eigenvalues of f @ f.T / f.T @ f

        evidences = []
        self.num_dim = y.shape[1] if self.regression else int(y.max() + 1)
        for i in range(self.num_dim):
            y_ = y[:, i] if self.regression else (y == i).astype(np.float64)
            y_ = y_.reshape(-1, 1)
            x = u.T @ y_  # x has shape [k, 1], but actually x should have shape [N, 1]
            x2 = x ** 2
            res_x2 = (y_ ** 2).sum() - x2.sum()  # if k < N, we compute sum of xi for 0 singular values directly

            alpha, beta = 1.0, 1.0
            for _ in range(11):
                t = alpha / beta  # current ridge coefficient
                gamma = (sigma / (sigma + t)).sum()  # effective degrees of freedom
                m2 = (sigma * x2 / ((t + sigma) ** 2)).sum()  # ||m||^2 in sigma-space
                res2 = (x2 / ((1 + sigma / t) ** 2)).sum() + res_x2  # residual term
                alpha = gamma / (m2 + 1e-5)
                beta = (N - gamma) / (res2 + 1e-5)
                t_ = alpha / beta
                evidence = D / 2.0 * np.log(alpha) \
                           + N / 2.0 * np.log(beta) \
                           - 0.5 * np.sum(np.log(alpha + beta * sigma)) \
                           - beta / 2.0 * res2 \
                           - alpha / 2.0 * m2 \
                           - N / 2.0 * np.log(2 * np.pi)
                evidence /= N
                if abs(t_ - t) / t <= 1e-3:  # abs(t_ - t) <= 1e-5 or abs(1 / t_ - 1 / t) <= 1e-5:
                    break
            # Re-evaluate the evidence with the final (alpha, beta) pair.
            evidence = D / 2.0 * np.log(alpha) \
                       + N / 2.0 * np.log(beta) \
                       - 0.5 * np.sum(np.log(alpha + beta * sigma)) \
                       - beta / 2.0 * res2 \
                       - alpha / 2.0 * m2 \
                       - N / 2.0 * np.log(2 * np.pi)
            evidence /= N
            # Map the posterior mean back from sigma-space to feature space.
            m = 1.0 / (t + sigma) * s * x
            m = (vh.T @ m).reshape(-1)
            evidences.append(evidence)
            self.alphas.append(alpha)
            self.betas.append(beta)
            self.ms.append(m)
        self.ms = np.stack(self.ms)
        return np.mean(evidences)

    # Default to the faster fixed-point algorithm (JMLR 2022 version).
    _fit = _fit_fixed_point

    def fit(self, f: np.ndarray, y: np.ndarray):
        """
        :param f: [N, F], feature matrix from pre-trained model
        :param y: target labels.
            For classification, y has shape [N] with element in [0, C_t).
            For regression, y has shape [N, C] with C regression-labels

        :return: LogME score (how well f can fit y directly)
        """
        if self.fitted:
            # Re-fitting is allowed but clears previously learned parameters.
            warnings.warn('re-fitting for new data. old parameters cleared.')
            self.reset()
        else:
            self.fitted = True
        f = f.astype(np.float64)
        if self.regression:
            y = y.astype(np.float64)
            if len(y.shape) == 1:
                # Promote a single regression target to a [N, 1] matrix.
                y = y.reshape(-1, 1)
        return self._fit(f, y)

    def predict(self, f: np.ndarray):
        """
        :param f: [N, F], feature matrix
        :return: prediction, return shape [N, X]
                 (raw regression outputs, or argmax class indices)
        :raises RuntimeError: if called before fit()
        """
        if not self.fitted:
            raise RuntimeError("not fitted, please call fit first")
        f = f.astype(np.float64)
        logits = f @ self.ms.T
        if self.regression:
            return logits
        return np.argmax(logits, axis=-1)
def NCE(source_label: np.ndarray, target_label: np.ndarray):
    """
    Negative Conditional Entropy score -H(Y|Z) (Tran et al., ICCV 2019).

    :param source_label: shape [N], elements in [0, C_s), often got from taken argmax from pre-trained predictions
    :param target_label: shape [N], elements in [0, C_t)
    :return: nce score (non-positive float; closer to 0 means better transferability)
    """
    # Cast once (mirrors the original per-element int(...) casts) so the
    # vectorized indexed accumulation below gets integer indices.
    source_label = source_label.reshape(-1).astype(int)
    target_label = target_label.reshape(-1).astype(int)
    C_t = int(np.max(target_label) + 1)  # the number of target classes
    C_s = int(np.max(source_label) + 1)  # the number of source classes
    N = len(source_label)
    joint = np.zeros((C_t, C_s), dtype=float)  # joint distribution, shape [C_t, C_s]
    # Vectorized replacement for the per-sample Python loop. np.add.at is the
    # unbuffered form of fancy-index +=, so repeated (t, s) pairs accumulate
    # correctly.
    np.add.at(joint, (target_label, source_label), 1.0 / N)
    p_z = joint.sum(axis=0, keepdims=True)  # marginal p(z), shape [1, C_s]
    p_target_given_source = (joint / p_z).T  # P(y | z), shape [C_s, C_t]
    mask = p_z.reshape(-1) != 0  # valid Z, shape [C_s]
    p_target_given_source = p_target_given_source[mask] + 1e-20  # remove NaN where p(z) = 0, add 1e-20 to avoid log (0)
    entropy_y_given_z = np.sum(- p_target_given_source * np.log(p_target_given_source), axis=1, keepdims=True)  # shape [C_s, 1]
    conditional_entropy = np.sum(entropy_y_given_z * p_z.reshape((-1, 1))[mask])  # scalar
    return - conditional_entropy
add 1e-20 to avoid log (0) 23 | entropy_y_given_z = np.sum(- p_target_given_source * np.log(p_target_given_source), axis=1, keepdims=True) # shape [C_s, 1] 24 | conditional_entropy = np.sum(entropy_y_given_z * p_z.reshape((-1, 1))[mask]) # scalar 25 | return - conditional_entropy 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LogME 2 | This is the codebase for the following two papers: 3 | 4 | - [LogME: Practical Assessment of Pre-trained Models for Transfer Learning](http://proceedings.mlr.press/v139/you21b.html), ICML 2021 5 | 6 | - [Ranking and Tuning Pre-trained Models: A New Paradigm for Exploiting Model Hubs](https://arxiv.org/abs/2110.10545), JMLR 2022 7 | 8 | **Note**: the second paper is an extended version of the first conference paper. 9 | 10 | # How to use 11 | 12 | ## Use LogME to assess transferability 13 | 14 | The API looks like sci-kit learn: first initialize an object, and then fit it to your data to get the transferability metric. 15 | 16 | By fitting the features ``f`` and labels ``y``, and you can get a nice score which well correlates with the transfer learning performance (without hyper-parameter tuning). 17 | 18 | **(1) For classification task:** 19 | 20 | ```python 21 | from LogME import LogME 22 | logme = LogME(regression=False) 23 | # f has shape of [N, D], y has shape [N] 24 | score = logme.fit(f, y) 25 | ``` 26 | 27 | **(2) For multi-label classification task:** 28 | 29 | ```python 30 | from LogME import LogME 31 | logme = LogME(regression=True) 32 | # f has shape of [N, D], y has shape [N, C] being the multi-label vector. 
33 | score = logme.fit(f, y) 34 | ``` 35 | 36 | **(3) For regression task:** 37 | 38 | ```python 39 | from LogME import LogME 40 | logme = LogME(regression=True) 41 | # f has shape of [N, D], y has shape [N, C] with C regression-labels 42 | score = logme.fit(f, y) 43 | ``` 44 | 45 | Then you can use the ``score`` to quickly select a good pre-trained model. The larger the ``score`` is, the better transfer performance you get. 46 | 47 | Meanwhile, the LogME score can also be used to purely measure the compatibility/transferability between features and labels, just like [this paper](https://arxiv.org/abs/2109.01087) from UC Berkeley. 48 | 49 | ## Ranking and Tuning pre-trained models 50 | 51 | ### Ranking pre-trained models 52 | 53 | ``ranking.py`` contains example code to rank pre-trained models, as well as to save the bayesian weight (m in LogME) for later B-Tuning 54 | 55 | 56 | FGVCAircraft dataset can be downloaded [here](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)). 57 | 58 | ```shell 59 | python ranking.py --dataset aircraft --data_path ./data/FGVCAircraft 60 | ``` 61 | 62 | You may get some outputs like the following: 63 | 64 | ```text 65 | Models ranking on aircraft: 66 | [('resnet152', 0.9501244943998941), 67 | ('resnet101', 0.948006158997241), 68 | ('mnasnet1_0', 0.947849273046989), 69 | ('resnet50', 0.9464738509680248), 70 | ('densenet169', 0.9434405008356792), 71 | ('densenet201', 0.9422277504393521), 72 | ('mobilenet_v2', 0.9412819194598648), 73 | ('inception_v3', 0.9398580258195871), 74 | ('densenet121', 0.9382284242364975), 75 | ('googlenet', 0.9338037297080976), 76 | ('resnet34', 0.9301353924624043)] 77 | ``` 78 | 79 | ### Tuning with multiple (heterogeneous) pre-trained models by B-Tuning 80 | 81 | ``b_tuning.py`` contains example code of the proposed B-Tuning. 
Typically, we can use the top-K models from the output of ``ranking.py``, just as follows: 82 | 83 | ```shell 84 | python b_tuning.py --dataset aircraft --data_path ./data/FGVCAircraft --model resnet50 --teachers resnet152 resnet101 mnasnet1_0 --tradeoff 100 85 | ``` 86 | 87 | Note that we use K=3 here, so the teachers are resnet152/resnet101/mnasnet1_0. We found K=3 is a good choice in general. 88 | 89 | # Code for LEEP and NCE 90 | 91 | We have received several requests for the code of LEEP and NCE, therefore we release the code in this repository to help the community. 92 | 93 | Please see the LEEP.py and NCE.py for details. LEEP/NCE in the paper were calculated by historical code with bugs. New results are available [here](https://github.com/WenWeiTHU/Transfer-Learning-Library/tree/dev-tllib/examples/model_adaption/model_selection), calculated by the LEEP/NCE code in this repo. 94 | 95 | Note that LEEP and NCE requires predictions over the pre-trained classes as input. The typical usage may look like: 96 | 97 | ```python 98 | # get the prediction of shape [N, C_s] from the pre-trained model 99 | # N is the number of samples, C_s is the number of pre-trained classes 100 | import numpy as np 101 | from LEEP import LEEP 102 | from NCE import NCE 103 | 104 | pseudo_source_label = xxx 105 | target_label = xxx # target_label has shape of [N], with its elements in [0, C_t) 106 | 107 | leep_score = LEEP(pseudo_source_label, target_label) 108 | nce_score = NCE(np.argmax(pseudo_source_label, axis=1), target_label) 109 | ``` 110 | 111 | # Citation 112 | 113 | If you find the code useful, please cite the following papers: 114 | 115 | ``` 116 | @inproceedings{you_logme:_2021, 117 | title = {LogME: Practical Assessment of Pre-trained Models for Transfer Learning}, 118 | booktitle = {ICML}, 119 | author = {You, Kaichao and Liu, Yong and Wang, Jianmin and Long, Mingsheng}, 120 | year = {2021} 121 | } 122 | 123 | @article{you_ranking_2022, 124 | title = {Ranking and Tuning 
models_list = ['mobilenet_v2', 'mnasnet1_0', 'densenet121', 'densenet169', 'densenet201',
               'resnet34', 'resnet50', 'resnet101', 'resnet152', 'googlenet', 'inception_v3']


def get_configs():
    """
    Parse command-line options for B-Tuning.

    :return: an argparse.Namespace with training, dataset, model, optimizer
             and experiment settings
    """
    parser = argparse.ArgumentParser(
        description='Bayesian tuning')

    # train
    parser.add_argument('--gpu', default=0, type=int,
                        help='GPU id for training')
    parser.add_argument('--seed', type=int, default=2021)

    parser.add_argument('--batch_size', default=48, type=int)
    parser.add_argument('--total_iter', default=9050, type=int)
    parser.add_argument('--eval_iter', default=1000, type=int)
    parser.add_argument('--save_iter', default=9000, type=int)
    parser.add_argument('--print_iter', default=100, type=int)

    # dataset
    parser.add_argument('--dataset', default="aircraft",
                        type=str, help='Name of dataset')
    parser.add_argument('--data_path', default="./data/FGVCAircraft",
                        type=str, help='Path of dataset')
    parser.add_argument('--num_workers', default=2, type=int,
                        help='Num of workers used in dataloading')

    # model
    parser.add_argument('--model', default="resnet50", choices=models_list,
                        type=str, help='Name of NN')
    # typo fix: "teahcer" -> "teacher" in the user-facing help text
    parser.add_argument('--teachers', nargs='+', help='Names of teacher models')
    # NOTE: the string default "196" is parsed through type=int by argparse,
    # so configs.class_num is the int 196.
    # NOTE(review): 196 looks like the Stanford Cars class count; FGVCAircraft
    # (the default dataset) has 100 variant classes — confirm before training.
    parser.add_argument('--class_num', default="196",
                        type=int, help='class number')

    # optimizer
    parser.add_argument('--lr', default=1e-3, type=float,
                        help='Learning rate for training')
    parser.add_argument('--gamma', default=0.1, type=float,
                        help='Gamma value for learning rate decay')
    # NOTE(review): type=bool is an argparse footgun — any non-empty string
    # (including "False") parses as True; only the default is reliable here.
    parser.add_argument('--nesterov', default=True,
                        type=bool, help='nesterov momentum')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='Momentum value for optimizer')
    parser.add_argument('--weight_decay', default=5e-4,
                        type=float, help='Weight decay value for optimizer')

    # experiment
    parser.add_argument('--root', default='.', type=str,
                        help='Root of the experiment')
    parser.add_argument('--name', default='b-tuning', type=str,
                        help='Name of the experiment')
    parser.add_argument('--save_dir', default="model",
                        type=str, help='Path of saved models')
    parser.add_argument('--visual_dir', default="visual",
                        type=str, help='Path of tensorboard data for training')
    parser.add_argument('--temperature', default=0.1, type=float,
                        metavar='P', help='temperature of logme weight')
    parser.add_argument('--tradeoff', default=100,
                        type=float, help='b-tuning tradeoff')
    configs = parser.parse_args()

    return configs
def str2list(v):
    """Split a comma-separated option string into a list of substrings."""
    parts = v.split(',')
    return parts


def str2bool(v):
    """Interpret common truthy strings ("yes"/"true"/"t"/"1", any case) as True."""
    truthy = ("yes", "true", "t", "1")
    return v.lower() in truthy


def get_writer(log_dir):
    """Create a TensorBoard summary writer that logs under *log_dir*."""
    writer = SummaryWriter(log_dir)
    return writer
def get_data_loader(configs):
    """
    Build train / validation / ten-crop test dataloaders for the target dataset.

    :param configs: parsed options; uses dataset, data_path, batch_size, num_workers
    :return: (train_loader, val_loader, test_loaders) where test_loaders maps
             'test0'..'test9' to one loader per crop
    :raises NotImplementedError: for datasets other than 'aircraft'
    """
    # data augmentation
    data_transforms = get_transforms(resize_size=256, crop_size=224)

    # build dataset
    if configs.dataset == 'aircraft':
        train_dataset = datasets.ImageFolder(
            os.path.join(configs.data_path, 'train'),
            transform=data_transforms['train'])
        # NOTE(review): the validation set applies the *train* transforms
        # (random augmentation at eval time) — looks unintended, but
        # get_transforms' available keys are defined elsewhere; confirm a
        # deterministic eval transform exists before changing this.
        val_dataset = datasets.ImageFolder(
            os.path.join(configs.data_path, 'test'),
            transform=data_transforms['train'])
        # One dataset per crop for ten-crop testing ('test0'..'test9').
        test_datasets = {
            'test' + str(i):
            datasets.ImageFolder(
                os.path.join(configs.data_path, 'test'),
                transform=data_transforms["test" + str(i)]
            )
            for i in range(10)
        }
    else:
        # try your customized dataset
        raise NotImplementedError

    # build dataloader
    train_loader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=True,
                              num_workers=configs.num_workers, pin_memory=True)

    val_loader = DataLoader(val_dataset, batch_size=configs.batch_size, shuffle=False,
                            num_workers=configs.num_workers, pin_memory=True)
    test_loaders = {
        'test' + str(i):
        DataLoader(
            test_datasets["test" + str(i)],
            batch_size=4, shuffle=False, num_workers=configs.num_workers
        )
        for i in range(10)
    }

    return train_loader, val_loader, test_loaders


def set_seeds(seed):
    """Seed numpy and torch (CPU + CUDA) and force deterministic cuDNN."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def train(configs, train_loader, val_loader, test_loaders, net, teachers):
    """
    B-Tuning training loop: cross-entropy on the student plus an MSE loss
    pulling the student's projected features towards a LogME-weighted mixture
    of the teachers' projected features.

    :param configs: parsed options
    :param train_loader, val_loader, test_loaders: dataloaders from get_data_loader
    :param net: student network (exposes f_net backbone and c_net classifier)
    :param teachers: list of dicts with keys 'name', 'model', 'weight'
    """
    # NOTE(review): len(train_loader) - 1 — presumably skips the last (short)
    # batch when re-creating the epoch iterator; confirm intent.
    train_len = len(train_loader) - 1
    train_iter = iter(train_loader)
    # Bayesian weight m (from LogME) for the student, saved by ranking.py.
    weight = torch.from_numpy(torch.load(f'logme_{configs.dataset}/weight_{configs.model}.pth')).float().cuda()

    # LogME scores of all models; softmax over the chosen teachers' scores
    # (sharpened by configs.temperature) gives the mixture weights pi.
    logmes = torch.load(f'logme_{configs.dataset}/results.pth')
    pi = []
    for teacher in teachers:
        pi.append(logmes[teacher['name']])
    pi = torch.softmax(torch.tensor(pi) / configs.temperature, dim=0).float().cuda()

    # different learning rates for different layers
    params_list = [{"params": filter(lambda p: p.requires_grad, net.f_net.parameters())},
                   {"params": filter(lambda p: p.requires_grad, net.c_net.parameters()), "lr": configs.lr * 10}]

    # optimizer and scheduler
    optimizer = torch.optim.SGD(params_list, lr=configs.lr, weight_decay=configs.weight_decay,
                                momentum=configs.momentum, nesterov=configs.nesterov)
    milestones = [6000]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones, gamma=configs.gamma)

    # check visual path
    visual_path = os.path.join(configs.visual_dir, configs.name)
    if not os.path.exists(visual_path):
        os.makedirs(visual_path)
    writer = get_writer(visual_path)

    # check model save path
    save_path = os.path.join(configs.save_dir, configs.name)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_dis = nn.MSELoss()

    # start training
    for iter_num in range(configs.total_iter):
        net.train()

        # Restart the data iterator whenever an "epoch" of train_len batches ends.
        if iter_num % train_len == 0:
            train_iter = iter(train_loader)

        # Data Stage
        data_start = time()

        train_inputs, train_labels = next(train_iter)
        train_inputs, train_labels = train_inputs.cuda(), train_labels.cuda()

        data_duration = time() - data_start

        # Calc Stage
        calc_start = time()

        train_features, train_outputs = net(train_inputs, with_feat=True)

        # NOTE: dead initialization — overwritten unconditionally below.
        b_tuning_loss = 0.0

        # Student features projected through its Bayesian weight.
        target_features = torch.matmul(train_features, weight.t())

        # Teachers' projected features, mixed by pi, computed without grad.
        source_features = torch.zeros_like(target_features)
        for i, teacher in enumerate(teachers):
            with torch.no_grad():
                # NOTE: 'input' shadows the builtin of the same name.
                input = train_inputs
                source_features += pi[i] * torch.matmul(teacher['model'](input), teacher['weight'].t())

        classifier_loss = criterion_cls(train_outputs, train_labels)
        b_tuning_loss = criterion_dis(target_features, source_features.detach())
        loss = classifier_loss + configs.tradeoff * b_tuning_loss

        writer.add_scalar('loss/b_tuning_loss', configs.tradeoff * b_tuning_loss, iter_num)
        writer.add_scalar('loss/classifier_loss', classifier_loss, iter_num)
        writer.add_scalar('loss/total_loss', loss, iter_num)

        net.zero_grad()
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()
        scheduler.step()

        calc_duration = time() - calc_start

        # Periodic validation.
        if iter_num % configs.eval_iter == 0:
            acc_meter = AccuracyMeter(topk=(1,))
            with torch.no_grad():
                net.eval()
                for val_inputs, val_labels in tqdm(val_loader):
                    val_inputs, val_labels = val_inputs.cuda(), val_labels.cuda()
                    val_outputs = net(val_inputs)
                    acc_meter.update(val_outputs, val_labels)
                writer.add_scalar('acc/val_acc', acc_meter.avg[1], iter_num)
                print(
                    "Iter: {}/{} Val_Acc: {:2f}".format(
                        iter_num, configs.total_iter, acc_meter.avg[1])
                )
                acc_meter.reset()

        # Periodic ten-crop test + checkpoint.
        if iter_num % configs.save_iter == 0 and iter_num > 0:
            test_acc = TenCropsTest(test_loaders, net)
            writer.add_scalar('acc/test_acc', test_acc, iter_num)
            print(
                "Iter: {}/{} Test_Acc: {:2f}".format(
                    iter_num, configs.total_iter, test_acc)
            )
            checkpoint = {
                'state_dict': net.state_dict(),
                'iter': iter_num,
                'acc': test_acc,
            }
            torch.save(checkpoint,
                       os.path.join(save_path, '{}.pkl'.format(iter_num)))
            print("Model Saved.")

        if iter_num % configs.print_iter == 0:
            print(
                "Iter: {}/{} Loss: {:2f} Loss_CLS: {:2f} Loss_KD: {:2f}, d/c: {}/{}".format(iter_num, configs.total_iter,
                                                                                            loss, classifier_loss, configs.tradeoff * b_tuning_loss, data_duration, calc_duration))
def load_model(configs, pretrained=True, only_feature=False):
    """
    Build a torchvision backbone with its classification head replaced by
    Identity (so it outputs features), optionally wrapped with a fresh linear
    classifier for the target task.

    :param configs: parsed options; uses configs.model and configs.class_num
    :param pretrained: load ImageNet-pretrained weights
    :param only_feature: if True return the headless backbone only (teacher use)
    :return: the backbone, or a Net(backbone + new linear head)
    """
    model = models.__dict__[configs.model](pretrained=pretrained)
    # Each architecture family names its final linear layer differently; strip
    # it and remember its input width as the feature dimension.
    if configs.model in ['mobilenet_v2', 'mnasnet1_0']:
        model.feature_dim = model.classifier[-1].in_features
        model.classifier[-1] = nn.Identity()
    elif configs.model in ['densenet121', 'densenet169', 'densenet201']:
        model.feature_dim = model.classifier.in_features
        model.classifier = nn.Identity()
    elif configs.model in [ 'resnet34', 'resnet50', 'resnet101',
                           'resnet152', 'googlenet', 'inception_v3']:
        model.feature_dim = model.fc.in_features
        model.fc = nn.Identity()

    if only_feature:
        return model

    class Net(nn.Module):
        # Backbone (f_net) plus a newly initialized linear classifier (c_net).
        def __init__(self, model, feature_dim):
            super(Net, self).__init__()
            self.f_net = model
            self.feature_dim = feature_dim
            self.c_net = nn.Linear(feature_dim, configs.class_num)
            self.c_net.weight.data.normal_(0, 0.01)
            self.c_net.bias.data.fill_(0.0)

        def forward(self, x, with_feat=False):
            # with_feat=True additionally returns the backbone features,
            # needed by the B-Tuning loss in train().
            feature = self.f_net(x)
            out = self.c_net(feature)
            if with_feat:
                return feature, out
            else:
                return out

    return Net(model, model.feature_dim)
def main():
    """
    Entry point: parse options, build data/model, load the teachers named on
    the command line (with their LogME Bayesian weights from ranking.py), then
    run B-Tuning training.
    """
    configs = get_configs()
    print(configs)
    torch.cuda.set_device(configs.gpu)
    set_seeds(configs.seed)

    train_loader, val_loader, test_loaders = get_data_loader(configs)

    net = load_model(configs).cuda()

    # configs.model is temporarily rebound to each teacher name so load_model
    # builds the right backbone; it is restored to the student afterwards.
    student = configs.model
    teachers = []
    for teacher_name in configs.teachers:
        assert teacher_name in models_list
        configs.model = teacher_name
        model_t_feat = load_model(configs, only_feature=True).cuda().eval()
        model_t = {'name':teacher_name,
                   'model': model_t_feat,
                   'weight': torch.from_numpy(torch.load(f'logme_{configs.dataset}/weight_{teacher_name}.pth')).float().cuda()
                   }
        teachers.append(model_t)

    configs.model = student
    train(configs, train_loader, val_loader, test_loaders, net, teachers)
def get_configs():
    """Parse command-line arguments for the model-ranking script.

    Returns an argparse.Namespace with the GPU id, batch size, dataset
    name/path and dataloader worker count.
    """
    parser = argparse.ArgumentParser(
        description='Ranking pre-trained models')
    parser.add_argument('--gpu', default=0, type=int,
                        help='GPU num for training')
    parser.add_argument('--batch_size', default=48, type=int)

    # dataset
    parser.add_argument('--dataset', default="aircraft",
                        type=str, help='Name of dataset')
    parser.add_argument('--data_path', default="/data/FGVCAircraft/train",
                        type=str, help='Path of dataset')
    parser.add_argument('--num_workers', default=2, type=int,
                        help='Num of workers used in dataloading')
    # model
    configs = parser.parse_args()

    return configs


def forward_pass(score_loader, model, fc_layer):
    """
    A forward pass on the target dataset.

    :params score_loader: the dataloader for scoring transferability
    :params model: the model for scoring transferability
    :params fc_layer: the fc layer of the model, for registering hooks
    returns
        features: extracted features of model (inputs to fc_layer), on CPU
        outputs: outputs of model (fc_layer outputs), on CPU
        targets: ground-truth labels of dataset
    """
    features = []
    outputs = []
    targets = []

    def hook_fn_forward(module, input, output):
        # `input` is the tuple of positional inputs to fc_layer;
        # input[0] is the penultimate feature tensor we want
        features.append(input[0].detach().cpu())
        outputs.append(output.detach().cpu())

    forward_hook = fc_layer.register_forward_hook(hook_fn_forward)

    model.eval()
    with torch.no_grad():
        for data, target in score_loader:
            targets.append(target)
            _ = model(data.cuda())

    # always remove the hook once the features are collected
    forward_hook.remove()
    features = torch.cat(features)
    outputs = torch.cat(outputs)
    targets = torch.cat(targets)

    return features, outputs, targets


def main():
    """Score every model in models_hub on the chosen dataset and print a ranking."""
    configs = get_configs()
    torch.cuda.set_device(configs.gpu)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    os.makedirs(f'logme_{configs.dataset}', exist_ok=True)
    score_dict = {}
    for model in models_hub:
        configs.model = model
        # inception_v3 is pretrained on 299x299 images; the others on 224x224
        input_size = (299, 299) if model == 'inception_v3' else (224, 224)
        transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            normalize
        ])
        score_dataset = datasets.ImageFolder(configs.data_path, transform=transform)
        # or try your customized dataset
        score_loader = DataLoader(score_dataset, batch_size=configs.batch_size, shuffle=False,
                                  num_workers=configs.num_workers, pin_memory=True)
        score_dict[model] = score_model(configs, score_loader)
    results = sorted(score_dict.items(), key=lambda item: item[1], reverse=True)
    torch.save(score_dict, f'logme_{configs.dataset}/results.pth')
    print(f'Models ranking on {configs.dataset}: ')
    pprint.pprint(results)


def score_model(configs, score_loader):
    """
    Compute the LogME transferability score of one pre-trained model.

    Also saves the fitted bayesian weights (logme.ms) to
    logme_<dataset>/weight_<model>.pth for later reuse.
    """
    print(f'Calc Transferabilities of {configs.model} on {configs.dataset}')

    if configs.model == 'inception_v3':
        # aux_logits disabled so the forward pass returns a single tensor
        model = models.__dict__[configs.model](pretrained=True, aux_logits=False).cuda()
    else:
        model = models.__dict__[configs.model](pretrained=True).cuda()

    # different model families name their final linear projection differently
    if configs.model in ['mobilenet_v2', 'mnasnet1_0']:
        fc_layer = model.classifier[-1]
    elif configs.model in ['densenet121', 'densenet169', 'densenet201']:
        fc_layer = model.classifier
    elif configs.model in ['resnet34', 'resnet50', 'resnet101', 'resnet152', 'googlenet', 'inception_v3']:
        fc_layer = model.fc
    else:
        # try your customized model
        raise NotImplementedError

    print('Conducting features extraction...')
    features, outputs, targets = forward_pass(score_loader, model, fc_layer)
    # predictions = F.softmax(outputs)

    print('Conducting transferability calculation...')
    logme = LogME(regression=False)
    score = logme.fit(features.numpy(), targets.numpy())

    # save calculated bayesian weight
    torch.save(logme.ms, f'logme_{configs.dataset}/weight_{configs.model}.pth')

    print(f'LogME of {configs.model}: {score}\n')
    return score


if __name__ == '__main__':
    main()

# ------------------------------------------------------------------------------
# utils/tools.py
# ------------------------------------------------------------------------------
import torch, os
from tqdm import trange


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # n is the number of samples `val` was averaged over
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class OnlineMeter(object):
    """Computes and stores the average and variance/std values of a tensor,
    updated online (Welford's algorithm)."""

    def __init__(self):
        # placeholder values until the first update() lazily allocates
        # accumulators shaped like the incoming tensor
        self.mean = torch.FloatTensor(1).fill_(-1)
        self.M2 = torch.FloatTensor(1).zero_()
        self.count = 0.
        self.needs_init = True

    def reset(self, x):
        # (re)allocate the accumulators with the shape/type of x
        self.mean = x.new(x.size()).zero_()
        self.M2 = x.new(x.size()).zero_()
        self.count = 0.
        self.needs_init = False

    def update(self, x):
        self.val = x
        if self.needs_init:
            self.reset(x)
        self.count += 1
        delta = x - self.mean
        self.mean.add_(delta / self.count)
        delta2 = x - self.mean
        self.M2.add_(delta * delta2)

    @property
    def var(self):
        if self.count < 2:
            return self.M2.clone().zero_()
        return self.M2 / (self.count - 1)

    @property
    def std(self):
        # bug fix: `var` is a property, so the previous `self.var()` tried to
        # call the returned tensor and raised TypeError
        return self.var.sqrt()


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t().type_as(target)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: correct[:k] is non-contiguous after t(),
        # and .view() raises on non-contiguous tensors
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))

    return res


class AccuracyMeter(object):
    """Computes and stores the average and current top-k accuracy."""

    def __init__(self, topk=(1,)):
        self.topk = topk
        self.reset()

    def reset(self):
        self._meters = {k: AverageMeter() for k in self.topk}

    def update(self, output, target):
        # bug fix: weight each batch by its size; `n` was previously computed
        # but unused, so unevenly sized batches skewed the running average
        n = target.nelement()
        acc_vals = accuracy(output, target, self.topk)
        for i, k in enumerate(self.topk):
            self._meters[k].update(acc_vals[i], n)

    @property
    def val(self):
        return {n: meter.val for (n, meter) in self._meters.items()}

    @property
    def avg(self):
        return {n: meter.avg for (n, meter) in self._meters.items()}

    @property
    def avg_error(self):
        return {n: 100. - meter.avg for (n, meter) in self._meters.items()}


def TenCropsTest(loader, net):
    """Ten-crop evaluation: sums the logits over the ten crop views
    (loader['test0'] .. loader['test9']) and reports top-1 accuracy."""
    with torch.no_grad():
        net.eval()
        val_len = len(loader['test0'])
        iter_val = [iter(loader['test' + str(i)]) for i in range(10)]
        all_outputs = None
        all_labels = None
        for _ in trange(val_len):
            # Python 3 fix: `iterator.next()` only exists in Python 2;
            # use the builtin next()
            data = [next(iter_val[j]) for j in range(10)]
            labels = data[0][1].cuda()
            # sum the logits over the ten crops of the same batch of images
            outputs = sum(net(data[j][0].cuda()) for j in range(10))
            if all_outputs is None:
                all_outputs = outputs.data.float()
                all_labels = labels.data.float()
            else:
                all_outputs = torch.cat((all_outputs, outputs.data.float()), 0)
                all_labels = torch.cat((all_labels, labels.data.float()), 0)
        acc_meter = AccuracyMeter(topk=(1,))
        acc_meter.update(all_outputs, all_labels)

    return acc_meter.avg[1]

# ------------------------------------------------------------------------------
# utils/transforms.py
# ------------------------------------------------------------------------------
from PIL import Image
from torchvision import transforms


class ResizeImage(object):
    """Resize a PIL image to a fixed size (square when given an int)."""

    def __init__(self, size):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        th, tw = self.size
        # NOTE(review): PIL resize takes (width, height); th/tw are equal in
        # every current call site so the order is moot here
        return img.resize((th, tw))


class PlaceCrop(object):
    """Crops the given PIL.Image at the particular index.

    Args:
        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (w, h), a square crop (size, size) is
            made.
        start_x: left coordinate of the crop box.
        start_y: upper coordinate of the crop box.
    """

    def __init__(self, size, start_x, start_y):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.start_x = start_x
        self.start_y = start_y

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: Cropped image.
        """
        th, tw = self.size
        return img.crop((self.start_x, self.start_y, self.start_x + tw, self.start_y + th))


class ForceFlip(object):
    """Always horizontally flip the given PIL.Image (deterministic — used to
    build the flipped half of the ten-crop set)."""

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be flipped.
        Returns:
            PIL.Image: Horizontally flipped image.
        """
        return img.transpose(Image.FLIP_LEFT_RIGHT)


def transform_train(resize_size=256, crop_size=224):
    """Training transform: resize, random resized crop, random flip, normalize."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    return transforms.Compose([
        ResizeImage(resize_size),
        transforms.RandomResizedCrop(crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])


def transform_val(resize_size=256, crop_size=224):
    """Validation transform: resize then a deterministic center crop."""
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    start_center = (resize_size - crop_size - 1) / 2

    return transforms.Compose([
        ResizeImage(resize_size),
        PlaceCrop(crop_size, start_center, start_center),
        transforms.ToTensor(),
        normalize
    ])


def transform_test(resize_size=256, crop_size=224):
    """Ten-crop test transforms.

    Returns a dict with keys 'test0'..'test9': the four corners plus the
    center crop, horizontally flipped for 'test0'-'test4' and unflipped for
    'test5'-'test9' (same pipelines as the original hand-unrolled version).
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # ten crops for image test
    start_first = 0
    start_center = (resize_size - crop_size - 1) / 2
    start_last = resize_size - crop_size - 1
    # crop origins, in the same order as the original test0..test4 keys
    starts = [(start_first, start_first),
              (start_last, start_last),
              (start_last, start_first),
              (start_first, start_last),
              (start_center, start_center)]
    data_transforms = {}
    for group, flip in enumerate((True, False)):
        for pos, (start_x, start_y) in enumerate(starts):
            ops = [ResizeImage(resize_size)]
            if flip:
                ops.append(ForceFlip())
            ops += [PlaceCrop(crop_size, start_x, start_y),
                    transforms.ToTensor(),
                    normalize]
            data_transforms['test' + str(group * 5 + pos)] = transforms.Compose(ops)

    return data_transforms


def get_transforms(resize_size=256, crop_size=224):
    """Bundle train/val/ten-crop-test transforms into a single dict."""
    # local renamed from `transforms` to avoid shadowing the torchvision module
    all_transforms = {
        'train': transform_train(resize_size, crop_size),
        'val': transform_val(resize_size, crop_size),
    }
    all_transforms.update(transform_test(resize_size, crop_size))

    return all_transforms