├── .gitignore ├── README.md ├── classification_tasks ├── data.py ├── eval.py ├── eval_feature_rank.py ├── feature_rank.ipynb ├── losses.py ├── models.py ├── models_densenet.py ├── plots.py ├── randomaug.py ├── sam.py ├── train.py ├── utils.py ├── utils_eval.py └── utils_train.py ├── contrastive_text_image_learning ├── clip_plots.ipynb ├── data.py ├── eval.py ├── eval_imagenet.py ├── main.py ├── models.py ├── utils.py └── vision_transformer │ ├── .github │ └── workflows │ │ └── build.yml │ ├── .gitignore │ ├── .style.yapf │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── README.md │ ├── lit.ipynb │ ├── mixer_figure.png │ ├── model_cards │ └── lit.md │ ├── setup.py │ ├── version.py │ ├── vit_figure.png │ ├── vit_jax.ipynb │ ├── vit_jax │ ├── __init__.py │ ├── checkpoint.py │ ├── checkpoint_test.py │ ├── configs │ │ ├── README.md │ │ ├── __init__.py │ │ ├── augreg.py │ │ ├── common.py │ │ ├── inference_time.py │ │ ├── mixer_base16_cifar10.py │ │ ├── models.py │ │ └── vit.py │ ├── inference_time.py │ ├── inference_time_test.py │ ├── input_pipeline.py │ ├── main.py │ ├── models.py │ ├── models_lit.py │ ├── models_mixer.py │ ├── models_resnet.py │ ├── models_test.py │ ├── models_vit.py │ ├── preprocess.py │ ├── preprocess_test.py │ ├── requirements-tpu.txt │ ├── requirements.txt │ ├── test_utils.py │ ├── train.py │ ├── train_test.py │ └── utils.py │ └── vit_jax_augreg.ipynb ├── requirements.txt ├── sam_low_rank_summary.png └── two_layer_nets ├── fc_nets.py ├── fc_nets_two_layer.ipynb └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # data files 4 | MNIST_DATA 5 | 6 | # compiled python files 7 | *.pyc 8 | 9 | .idea/ 10 | .vscode/ 11 | 12 | images/ 13 | plots/ 14 | plots_wd/ 15 | extended_stats/ 16 | metrics_loss_surface/ 17 | exps/* 18 | exps_old/ 19 | exps_sgdm/ 20 | exps_backup/ 21 | exps_backup_29jan/ 22 | logs/ 23 | logs_eval/ 24 | results_dir/ 25 | debug_logs/ 26 | apex/ 27 | aa_eval/ 28 | models/ 29 | models_simple/ 30 | deltas/ 31 | nohup.out 32 | data/ 33 | -------------------------------------------------------------------------------- /classification_tasks/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import time 6 | import data 7 | import models 8 | import utils_eval 9 | import losses 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--gpu', default=0, type=int) 15 | parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'svhn', 'cifar10', 'cifar10_horse_car', 'cifar10_dog_cat', 'uniform_noise'], type=str) 16 | parser.add_argument('--model', default='resnet18', choices=['resnet18', 'cnn', 'fc', 'linear', 'lenet'], type=str) 17 | parser.add_argument('--set', default='test', type=str, choices=['train', 'test']) 18 | parser.add_argument('--weight_modification', default='none', type=str, choices=['none', 'sam_m_full', 'sam_m_1']) 19 | parser.add_argument('--model_path', default='2020-05-01 19:16:54.509 dataset=cifar10 model=resnet18 eps=8.0 attack=fgsm attack_init=zero fgsm_alpha=1.25 epochs=30 pgd=2.0-10 grad_align_cos_lambda=0.2 cure_lambda=0.0 lr_max=0.3 seed=0 epoch=30', 20 | type=str, help='model name') 21 | parser.add_argument('--model_short', default='noname', type=str, help='model name') 22 | parser.add_argument('--seed', default=0, type=int) 23 | parser.add_argument('--eps', default=8, type=float) 24 | 
parser.add_argument('--n_eval', default=-1, type=int, help='#examples to evaluate on') 25 | parser.add_argument('--loss', default='ce', choices=['ce', 'ce_offset', 'gce', 'smallest_k_ce'], type=str, help='Loss type.') 26 | parser.add_argument('--sam_rho', default=0.2, type=float, help='step size for SAM (sharpness-aware minimization)') 27 | parser.add_argument('--p_label_noise', default=0.0, type=float, help='label noise for evaluation') 28 | parser.add_argument('--activation', default='relu', type=str, help='currently supported only for resnet. relu or softplus* where * corresponds to the softplus alpha') 29 | parser.add_argument('--pgd_rr_n_iter', default=50, type=int, help='pgd rr number of iterations') 30 | parser.add_argument('--pgd_rr_n_restarts', default=10, type=int, help='pgd rr number of restarts') 31 | parser.add_argument('--n_layers', default=1, type=int, help='#layers on each conv layer (for model in [fc, cnn])') 32 | parser.add_argument('--model_width', default=64, type=int, help='model width (# conv filters on the first layer for ResNets)') 33 | parser.add_argument('--n_filters_cnn', default=16, type=int, help='#filters on each conv layer (for model==cnn)') 34 | parser.add_argument('--n_hidden_fc', default=1024, type=int, help='#filters on each conv layer (for model==fc)') 35 | parser.add_argument('--batch_size_eval', default=512, type=int, help='batch size for evaluation') 36 | parser.add_argument('--half_prec', action='store_true', help='eval in half precision') 37 | parser.add_argument('--early_stopped_model', action='store_true', help='eval the best model according to pgd_acc evaluated every k iters (typically, k=200)') 38 | parser.add_argument('--eval_grad_norm', action='store_true', help='evaluate the gradient norm') 39 | parser.add_argument('--aa_eval', action='store_true', help='perform autoattack evaluation') 40 | return parser.parse_args() 41 | 42 | 43 | start_time = time.time() 44 | args = get_args() 45 | eps = args.eps / 255 46 | half_prec = args.half_prec # for more reliable evaluations: keep in the single precision 47 | print_stats = False 48 | n_cls = 2 if args.dataset in ['cifar10_horse_car', 'cifar10_dog_cat'] else 10 49 | 50 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 51 | 52 | np.set_printoptions(precision=4, suppress=True) 53 | np.random.seed(args.seed) 54 | torch.manual_seed(args.seed) 55 | torch.cuda.manual_seed(args.seed) 56 | 57 | scaler = torch.cuda.amp.GradScaler(enabled=False) 58 | loss_dict = { 59 | 'ce': losses.cross_entropy(), 60 | 'ce_offset': losses.cross_entropy_with_offset(loss_offset=0.1), 61 | 'gce': losses.generalized_cross_entropy(q=0.7), 62 | 'smallest_k_ce': losses.smallest_k_cross_entropy(frac_rm_per_batch=0.0) 63 | } 64 | loss_f = loss_dict[args.loss] 65 | 66 | model = models.get_model(args.model, n_cls, args.half_prec, data.shapes_dict[args.dataset], args.model_width, args.activation) 67 | model = model.cuda().eval() 68 | 69 | model_dict = torch.load('models/{}.pth'.format(args.model_path)) 70 | model_dict = model_dict['best'] if args.early_stopped_model else model_dict['last'] 71 | model.load_state_dict({k: v for k, v in model_dict.items() if 'model_preact_hl1' not in k}) 72 | 73 | opt = torch.optim.SGD(model.parameters(), lr=0, momentum=0.9) 74 | 75 | # important to exclude the validation samples to get the correct training error 76 | n_val = int(0.1 * data.shapes_dict[args.dataset][0]) 77 | val_indices = np.random.permutation(data.shapes_dict[args.dataset][0])[:n_val] 78 | train_batches = data.get_loaders(args.dataset, 
args.n_eval, args.batch_size_eval, split='train', shuffle=False, 79 | data_augm=False, drop_last=False, p_label_noise=args.p_label_noise, val_indices=val_indices) 80 | test_batches = data.get_loaders(args.dataset, args.n_eval, args.batch_size_eval, split='test', shuffle=False, 81 | data_augm=False, drop_last=False, p_label_noise=args.p_label_noise) 82 | 83 | 84 | # import ipdb;ipdb.set_trace() 85 | if args.eval_grad_norm: 86 | n_ex, grad_norm_total = 0, 0.0 87 | for i, (X, X_augm2, y, _, ln) in enumerate(train_batches): 88 | X, y = X.cuda(), y.cuda() 89 | 90 | output = model(X) 91 | loss = loss_f(output, y) 92 | loss.backward() 93 | grad_norm = sum([torch.sum(p.grad**2) for p in model.parameters()])**0.5 94 | grad_norm_total += grad_norm.item() 95 | n_ex += y.size(0) 96 | opt.zero_grad() 97 | 98 | grad_norm_total /= n_ex 99 | print(grad_norm_total) 100 | 101 | 102 | train_err, train_loss, _ = utils_eval.rob_err(train_batches, model, 0, 0, scaler, 0, 0) 103 | test_err, test_loss, _ = utils_eval.rob_err(test_batches, model, 0, 0, scaler, 0, 0) 104 | 105 | print('err={:.2%}/{:.2%}, loss={:.4f}/{:.4f}'.format(train_err, test_err, train_loss, test_loss)) 106 | 107 | if args.aa_eval: 108 | from autoattack import autoattack 109 | images, labels, _, _ = data.get_xy_from_loader(test_batches) 110 | adversary = autoattack.AutoAttack(model, norm='Linf', eps=eps) 111 | x_adv = adversary.run_standard_evaluation(images, labels, bs=args.batch_size_eval) 112 | 113 | 114 | time_elapsed = time.time() - start_time 115 | print('Done in {:.2f}m'.format((time.time() - start_time) / 60)) 116 | 117 | -------------------------------------------------------------------------------- /classification_tasks/eval_feature_rank.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | import time 6 | import data 7 | import models 8 | import utils_eval 9 | import losses 10 | 11 | 12 | def get_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--gpu', default=0, type=int) 15 | parser.add_argument('--dataset', default='cifar10', choices=['mnist', 'svhn', 'cifar10', 'cifar100', 'cifar10_horse_car', 'cifar10_dog_cat', 'uniform_noise'], type=str) 16 | parser.add_argument('--model', default='resnet18_plain', choices=['resnet18_plain', 'resnet18'], type=str) 17 | parser.add_argument('--set', default='test', type=str, choices=['train', 'test']) 18 | parser.add_argument('--model_path', default='', type=str, help='model name') 19 | parser.add_argument('--seed', default=0, type=int) 20 | parser.add_argument('--n_eval', default=-1, type=int, help='#examples to evaluate on') 21 | parser.add_argument('--batch_size_eval', default=1024, type=int, help='batch size') 22 | parser.add_argument('--model_width', default=64, type=int, help='model width (# conv filters on the first layer for ResNets)') 23 | return parser.parse_args() 24 | 25 | 26 | start_time = time.time() 27 | args = get_args() 28 | rho = args.model_path.split('sam_rho=')[1].split(' ')[0] 29 | n_cls = 100 if args.dataset == 'cifar100' else 10 30 | scaler = torch.cuda.amp.GradScaler(enabled=False) 31 | loss_f = losses.cross_entropy() 32 | 33 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 34 | 35 | np.set_printoptions(precision=4, suppress=True) 36 | np.random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | torch.cuda.manual_seed(args.seed) 39 | 40 | model = models.get_model(args.model, n_cls, False, data.shapes_dict[args.dataset], 
args.model_width).cuda().eval() 41 | 42 | model_dict = torch.load('models/{}.pth'.format(args.model_path))['last'] 43 | model.load_state_dict({k: v for k, v in model_dict.items()}) 44 | 45 | # important to exclude the validation samples to get the correct training error 46 | n_val = int(0.001 * data.shapes_dict[args.dataset][0]) 47 | val_indices = np.random.permutation(data.shapes_dict[args.dataset][0])[:n_val] 48 | train_batches = data.get_loaders(args.dataset, args.n_eval, args.batch_size_eval, split='train', shuffle=False, 49 | data_augm=False, drop_last=False, p_label_noise=0.0, val_indices=val_indices) 50 | test_batches = data.get_loaders(args.dataset, args.n_eval, args.batch_size_eval, split='test', shuffle=False, 51 | data_augm=False, drop_last=False, p_label_noise=0.0) 52 | 53 | 54 | train_err, train_loss, _ = utils_eval.rob_err(train_batches, model, 0, 0, scaler, 0, 0) 55 | test_err, test_loss, _ = utils_eval.rob_err(test_batches, model, 0, 0, scaler, 0, 0) 56 | print('test_err={:.2%}, train_err={:.2%}, train_loss={:.5f}'.format(test_err, train_err, train_loss)) 57 | 58 | 59 | feature_sing_vals, avg_sparsities, ns_active_relus_0p, ns_active_relus_1p, ns_active_relus_5p, ns_active_relus_10p = [], [], [], [], [], [] 60 | 61 | for i in [1, 2, 3, 4, 5]: 62 | feature_sing_vals += [utils_eval.compute_feature_sing_vals(train_batches, model, return_block=i)] 63 | 64 | phi = utils_eval.compute_feature_matrix(train_batches, model, return_block=i) 65 | relu_threshold = phi.max() / 20 66 | avg_sparsities += [(phi > relu_threshold).mean()] 67 | ns_active_relus_0p += [((phi > relu_threshold).sum(0) > phi.shape[0] * 0.0).sum()] 68 | ns_active_relus_1p += [((phi > relu_threshold).sum(0) > phi.shape[0] * 0.01).sum()] 69 | ns_active_relus_5p += [((phi > relu_threshold).sum(0) > phi.shape[0] * 0.05).sum()] 70 | ns_active_relus_10p += [((phi > relu_threshold).sum(0) > phi.shape[0] * 0.1).sum()] 71 | 72 | metrics = { 73 | 'rho': rho, 74 | 'test_err': test_err, 75 | 'feature_ranks': [np.sum(np.cumsum(svals**2) <= np.sum(svals**2) * 0.99) + 1 for svals in feature_sing_vals], 76 | 'avg_sparsities': avg_sparsities, 77 | 'ns_active_relus_0p': ns_active_relus_0p, 78 | 'ns_active_relus_1p': ns_active_relus_1p, 79 | 'ns_active_relus_5p': ns_active_relus_5p, 80 | 'ns_active_relus_10p': ns_active_relus_10p, 81 | } 82 | print(str(metrics) + ',') 83 | time_elapsed = time.time() - start_time 84 | print('Done in {:.2f}m'.format((time.time() - start_time) / 60)) 85 | 86 | -------------------------------------------------------------------------------- /classification_tasks/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def cross_entropy(reduction=True, label_smoothing=0.0): 6 | def loss_f(logits, y): 7 | loss = F.cross_entropy(logits, y, reduction='mean' if reduction else 'none', label_smoothing=label_smoothing) 8 | return loss 9 | 10 | return loss_f 11 | 12 | 13 | def cross_entropy_with_offset(loss_offset): 14 | """ Works only for binary classification. """ 15 | def loss_f(logits, y, reduction=True): 16 | assert max(y).item() <= 1 and min(y).item() >= 0 # i.e. binary labels directly encoded as 0 or 1 17 | y_plus_minus = 2 * (y - 0.5) 18 | loss = torch.log(1 + torch.exp(-y_plus_minus * (logits[:, 1] - logits[:, 0]) + loss_offset)) 19 | return loss.mean() if reduction else loss 20 | 21 | return loss_f 22 | 23 | 24 | def logistic_loss(): 25 | """ Works only for binary classification. 
Assumes logits have only 1 prediction. """ 26 | def loss_f(logits, y, reduction=True): 27 | assert max(y).item() <= 1 and min(y).item() >= 0 # i.e. binary labels directly encoded as 0 or 1 28 | y_plus_minus = 2 * (y - 0.5) 29 | loss = torch.log(1 + torch.exp(-y_plus_minus * (logits[:, 1] - logits[:, 0]))) 30 | return loss.mean() if reduction else loss 31 | 32 | return loss_f 33 | 34 | 35 | def generalized_cross_entropy(q): 36 | def loss_f(logits, y, reduction=True): 37 | p = F.softmax(logits, dim=1) 38 | p_y = p[range(p.shape[0]), y] 39 | loss = 1/q * (1 - p_y**q) 40 | return loss.mean() if reduction else loss 41 | 42 | return loss_f 43 | 44 | 45 | def smallest_k_cross_entropy(frac_rm_per_batch): 46 | def loss_f(logits, y, reduction=True): 47 | k_keep = y.shape[0] - int(frac_rm_per_batch * y.shape[0]) 48 | loss = F.cross_entropy(logits, y, reduction='none') 49 | loss = torch.topk(loss, k_keep, largest=False)[0] # take `k_keep` smallest losses 50 | return loss.mean() if reduction else loss 51 | 52 | return loss_f 53 | 54 | 55 | def logistic_loss_der(logits, y): 56 | """ Works only for binary classification. Assumes logits have only 1 prediction. """ 57 | y_plus_minus = 2 * (y - 0.5) 58 | der = -y_plus_minus/(1 + torch.exp(y_plus_minus * (logits[:, 1] - logits[:, 0]))) 59 | return der 60 | 61 | 62 | def square_loss(reduction=True): 63 | def loss_f(logits, y, reduction=True): 64 | loss = 0.5*torch.mean((logits - F.one_hot(y))**2, axis=1) 65 | return loss.mean() if reduction else loss 66 | 67 | return loss_f 68 | 69 | -------------------------------------------------------------------------------- /classification_tasks/models_densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | import torchvision.datasets as dset 10 | import torchvision.transforms as transforms 11 | from torch.utils.data import DataLoader 12 | 13 | import torchvision.models as models 14 | 15 | import sys 16 | import math 17 | 18 | class Bottleneck(nn.Module): 19 | def __init__(self, nChannels, growthRate): 20 | super(Bottleneck, self).__init__() 21 | interChannels = 4*growthRate 22 | self.bn1 = nn.BatchNorm2d(nChannels) 23 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, 24 | bias=False) 25 | self.bn2 = nn.BatchNorm2d(interChannels) 26 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, 27 | padding=1, bias=False) 28 | 29 | def forward(self, x): 30 | out = self.conv1(F.relu(self.bn1(x))) 31 | out = self.conv2(F.relu(self.bn2(out))) 32 | out = torch.cat((x, out), 1) 33 | return out 34 | 35 | class SingleLayer(nn.Module): 36 | def __init__(self, nChannels, growthRate): 37 | super(SingleLayer, self).__init__() 38 | self.bn1 = nn.BatchNorm2d(nChannels) 39 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, 40 | padding=1, bias=False) 41 | 42 | def forward(self, x): 43 | out = self.conv1(F.relu(self.bn1(x))) 44 | out = torch.cat((x, out), 1) 45 | return out 46 | 47 | class Transition(nn.Module): 48 | def __init__(self, nChannels, nOutChannels): 49 | super(Transition, self).__init__() 50 | self.bn1 = nn.BatchNorm2d(nChannels) 51 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, 52 | bias=False) 53 | 54 | def forward(self, x, return_after_relu=False): 55 | if return_after_relu: 56 | out = F.relu(self.bn1(x)) 57 | else: 58 | out = self.conv1(F.relu(self.bn1(x))) 59 | out = 
F.avg_pool2d(out, 2) 60 | return out 61 | 62 | 63 | class DenseNet(nn.Module): 64 | def __init__(self, growthRate=12, depth=100, reduction=0.5, nClasses=10, bottleneck=True): 65 | super(DenseNet, self).__init__() 66 | 67 | nDenseBlocks = (depth-4) // 3 68 | if bottleneck: 69 | nDenseBlocks //= 2 70 | 71 | nChannels = 2*growthRate 72 | self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, 73 | bias=False) 74 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 75 | nChannels += nDenseBlocks*growthRate 76 | nOutChannels = int(math.floor(nChannels*reduction)) 77 | self.trans1 = Transition(nChannels, nOutChannels) 78 | 79 | nChannels = nOutChannels 80 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 81 | nChannels += nDenseBlocks*growthRate 82 | nOutChannels = int(math.floor(nChannels*reduction)) 83 | self.trans2 = Transition(nChannels, nOutChannels) 84 | 85 | nChannels = nOutChannels 86 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 87 | nChannels += nDenseBlocks*growthRate 88 | 89 | self.bn1 = nn.BatchNorm2d(nChannels) 90 | self.fc = nn.Linear(nChannels, nClasses) 91 | 92 | for m in self.modules(): 93 | if isinstance(m, nn.Conv2d): 94 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 95 | m.weight.data.normal_(0, math.sqrt(2. / n)) 96 | elif isinstance(m, nn.BatchNorm2d): 97 | m.weight.data.fill_(1) 98 | m.bias.data.zero_() 99 | elif isinstance(m, nn.Linear): 100 | m.bias.data.zero_() 101 | 102 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): 103 | layers = [] 104 | for i in range(int(nDenseBlocks)): 105 | if bottleneck: 106 | layers.append(Bottleneck(nChannels, growthRate)) 107 | else: 108 | layers.append(SingleLayer(nChannels, growthRate)) 109 | nChannels += growthRate 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x, return_features=False, return_block=5): 113 | out = self.conv1(x) 114 | if return_features and return_block == 1: 115 | return out 116 | if return_features and return_block == 2: 117 | out = self.trans1(self.dense1(out), return_after_relu=True) 118 | return out 119 | else: 120 | out = self.trans1(self.dense1(out)) 121 | if return_features and return_block == 3: 122 | out = self.trans2(self.dense2(out), return_after_relu=True) 123 | return out 124 | else: 125 | out = self.trans2(self.dense2(out)) 126 | out = self.dense3(out) 127 | out = F.relu(self.bn1(out)) 128 | if return_features and return_block == 4: 129 | return out 130 | 131 | out = F.avg_pool2d(out, out.shape[-1]).squeeze(3).squeeze(2) 132 | if return_features and return_block == 5: 133 | return out 134 | out = F.log_softmax(self.fc(out), dim=-1) 135 | return out 136 | 137 | -------------------------------------------------------------------------------- /classification_tasks/plots.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def plot_loss(losses, training, random_step_fgsm, epoch, loss_type, i_ex): 7 | sns.set() 8 | plot_name = 'training={}_random_step_fgsm={}_loss_epoch={}_i_ex={}_type={}'.format( 9 | training, random_step_fgsm, epoch, i_ex, loss_type) 10 | marker_size, line_width = 5.0, 0.5 11 | interp_vals = np.arange(len(losses)) 12 | x_values = (interp_vals - max(interp_vals)/2) / (max(interp_vals)/2) 13 | ax = sns.lineplot(x_values, losses, linewidth=line_width, 14 | marker='o', markersize=marker_size, color="black") 15 | 
ax.set_xlabel('Interpolation coefficient') 16 | ax.set_ylabel('Adversarial loss') 17 | # ax.legend(loc='best', prop={'size': 12}) 18 | ax.set_title(plot_name) 19 | plt.savefig('plots/{}.pdf'.format(plot_name), bbox_inches='tight') 20 | plt.close() 21 | 22 | 23 | def histogram_delta(delta, attack, rs_train, rs_attack): 24 | sns.set() 25 | plot_name = 'histogram_delta-attack={}-rs_train={}-rs_attack={}'.format(attack, rs_train, rs_attack) 26 | sns.distplot(delta.flatten().cpu(), kde=False, rug=False, hist_kws={'log': True}) 27 | plt.savefig('plots/{}.pdf'.format(plot_name), bbox_inches='tight') 28 | plt.close() 29 | 30 | -------------------------------------------------------------------------------- /classification_tasks/randomaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3, 0.3] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(np.int_) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return 
PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.)) 136 | y0 = int(max(0, y0 - v / 2.)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 oeprations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | (Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0., 0.3), 197 | (ShearY, 0., 0.3), 198 | (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0., 100), 200 | (TranslateYabs, 0., 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = self.eigvec.type_as(img).clone() \ 220 | .mul(alpha.view(1, 3).expand(3, 3)) \ 221 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 222 | .sum(1).squeeze() 223 | 224 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 225 | 226 | 227 | class CutoutDefault(object): 228 | """ 229 | Reference : 
https://github.com/quark0/darts/blob/master/cnn/utils.py 230 | """ 231 | def __init__(self, length): 232 | self.length = length 233 | 234 | def __call__(self, img): 235 | h, w = img.size(1), img.size(2) 236 | mask = np.ones((h, w), np.float32) 237 | y = np.random.randint(h) 238 | x = np.random.randint(w) 239 | 240 | y1 = np.clip(y - self.length // 2, 0, h) 241 | y2 = np.clip(y + self.length // 2, 0, h) 242 | x1 = np.clip(x - self.length // 2, 0, w) 243 | x2 = np.clip(x + self.length // 2, 0, w) 244 | 245 | mask[y1: y2, x1: x2] = 0. 246 | mask = torch.from_numpy(mask) 247 | mask = mask.expand_as(img) 248 | img *= mask 249 | return img 250 | 251 | 252 | class RandAugment: 253 | def __init__(self, n, m): 254 | self.n = n 255 | self.m = m # [0, 30] 256 | self.augment_list = augment_list() 257 | 258 | def __call__(self, img): 259 | ops = random.choices(self.augment_list, k=self.n) 260 | for op, minval, maxval in ops: 261 | val = (float(self.m) / 30) * float(maxval - minval) + minval 262 | img = op(img, val) 263 | 264 | return img 265 | 266 | -------------------------------------------------------------------------------- /classification_tasks/sam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import utils 3 | import copy 4 | 5 | 6 | class SAM(torch.optim.Optimizer): 7 | """ 8 | Implementation of SAM is based on https://github.com/davda54/sam/blob/main/sam.py. 9 | """ 10 | def __init__(self, params, base_optimizer, rho, sam_no_grad_norm, only_sam_step_size=False, **kwargs): 11 | defaults = dict(rho=rho, **kwargs) 12 | super(SAM, self).__init__(params, defaults) 13 | 14 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) 15 | self.param_groups = self.base_optimizer.param_groups 16 | self.sam_no_grad_norm = sam_no_grad_norm 17 | # `only_sam_step_size`: we modify the step size (grad(w_sam) / grad(w) multiplicative factor) but keep the 18 | # original gradient direction as in the point `w`. 19 | self.only_sam_step_size = only_sam_step_size 20 | self.grad_norm_w, self.grad_norm_w_sam = None, None 21 | self.grad_w = dict() 22 | 23 | def _grad_norm(self): 24 | # put everything on the same device, in case of model parallelism 25 | shared_device = self.param_groups[0]["params"][0].device 26 | norm = torch.norm( 27 | torch.stack([ 28 | p.grad.norm(p=2).to(shared_device) 29 | for group in self.param_groups for p in group["params"] 30 | if p.grad is not None 31 | ]), 32 | p=2 33 | ) 34 | return norm 35 | 36 | @torch.no_grad() 37 | def first_step(self): 38 | """ 39 | At the beginning of `first_step()`, the grads are at point `w`. 
40 | Then this method updates the main model parameters in opt.param_groups[0]['params'] 41 | """ 42 | grad_norm = 1 if self.sam_no_grad_norm else self._grad_norm() 43 | self.grad_norm_w = self._grad_norm() 44 | # print(self.grad_norm_w) 45 | for group in self.param_groups: # by default, there is only 1 param group 46 | # standard SGD optimizer contains the following keys: 47 | # ['params', 'lr', 'momentum', 'dampening', 'weight_decay', 'nesterov'] 48 | scale = group["rho"] / (grad_norm + 1e-12) 49 | 50 | for param in group["params"]: # group["params"] is a list with 56 params 51 | if param.grad is None: continue 52 | delta_w = scale * param.grad 53 | param.add_(delta_w) # climb to the local maximum "w + e(w)" 54 | # by default, opt.state==defaultdict(, {}) but we can use it to store the SAM's delta 55 | self.state[param]["delta_w"] = delta_w # only the ref to the `param.data` is used as the key 56 | self.grad_w[param] = param.grad.clone() # store it to apply on 2nd step (if only_sam_step_size==True) 57 | 58 | self.zero_grad() # and we zero out the first grads (since we've already stored them) 59 | 60 | @torch.no_grad() 61 | def second_step(self): 62 | """ 63 | At the beginning of `second_step()`, the grads are at point `w + delta`. 64 | """ 65 | for group in self.param_groups: 66 | for param in group["params"]: 67 | if param.grad is None: continue 68 | param.sub_(self.state[param]["delta_w"]) # get back to `w` from `w + delta` 69 | 70 | if self.only_sam_step_size: # put the original gradient and change only the step size 71 | self.grad_norm_w_sam = self._grad_norm() 72 | for group in self.param_groups: 73 | for param in group['params']: 74 | param.grad = self.grad_w[param] * self.grad_norm_w_sam / self.grad_norm_w 75 | # param.grad = param.grad * self.grad_norm_w / self.grad_norm_w_sam 76 | 77 | @torch.no_grad() 78 | def step(self, closure=None): 79 | self.base_optimizer.step() 80 | 81 | 82 | def zero_init_delta_dict(delta_dict, rho): 83 | for param in delta_dict: 84 | delta_dict[param] = torch.zeros_like(param).cuda() 85 | 86 | delta_norm = torch.cat([delta_param.flatten() for delta_param in delta_dict.values()]).norm() 87 | for param in delta_dict: 88 | delta_dict[param] *= rho / delta_norm 89 | 90 | return delta_dict 91 | 92 | def random_init_on_sphere_delta_dict(delta_dict, rho): 93 | for param in delta_dict: 94 | delta_dict[param] = torch.randn_like(param).cuda() 95 | 96 | delta_norm = torch.cat([delta_param.flatten() for delta_param in delta_dict.values()]).norm() 97 | for param in delta_dict: 98 | delta_dict[param] *= rho / delta_norm 99 | 100 | return delta_dict 101 | 102 | 103 | def weight_ascent_step(model, f, orig_param_dict, delta_dict, step_size, rho, layer_name_pattern='all', no_grad_norm=False, verbose=False): 104 | utils.zero_grad(model) 105 | obj = f(model) # grads are accumulated 106 | obj.backward() 107 | 108 | grad_norm = utils.get_flat_grad(model).norm() 109 | if verbose: 110 | print('obj={:.3f}, grad_norm={:.3f}'.format(obj, grad_norm)) 111 | 112 | for param_name, param in model.named_parameters(): 113 | if layer_name_pattern == 'all' or layer_name_pattern in param_name: 114 | if no_grad_norm: 115 | delta_dict[param] += step_size * param.grad 116 | else: 117 | delta_dict[param] += step_size / (grad_norm + 1e-7) * param.grad 118 | 119 | delta_norm = torch.cat([delta_param.flatten() for delta_param in delta_dict.values()]).norm() 120 | if delta_norm > rho: 121 | for param in delta_dict: 122 | delta_dict[param] *= rho / delta_norm 123 | 124 | # now apply the 
(potentially) scaled perturbation to modify the weight 125 | for param in model.parameters(): 126 | param.data = orig_param_dict[param] + delta_dict[param] 127 | 128 | utils.zero_grad(model) 129 | return delta_dict 130 | 131 | 132 | def perturb_weights_sam(model, f, rho, step_size, n_iters, no_grad_norm, rand_init=False, verbose=False): 133 | delta_dict = {param: torch.zeros_like(param) for param in model.parameters()} 134 | 135 | # random init on the sphere of radius `rho` 136 | if rand_init: 137 | delta_dict = random_init_on_sphere_delta_dict(delta_dict, rho) 138 | for param in model.parameters(): 139 | param.data += delta_dict[param] 140 | 141 | orig_param_dict = {param: param.clone() for param in model.parameters()} 142 | 143 | for iter in range(n_iters): 144 | delta_dict = weight_ascent_step(model, f, orig_param_dict, delta_dict, step_size, rho, no_grad_norm=no_grad_norm, verbose=False) 145 | 146 | return delta_dict 147 | 148 | -------------------------------------------------------------------------------- /classification_tasks/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from contextlib import contextmanager 4 | 5 | 6 | logger = logging.getLogger(__name__) 7 | logging.basicConfig( 8 | format='[%(asctime)s %(filename)s %(name)s %(levelname)s] - %(message)s', 9 | datefmt='%Y/%m/%d %H:%M:%S', 10 | level=logging.DEBUG) 11 | 12 | 13 | def clamp(X, l, u, cuda=True): 14 | if type(l) is not torch.Tensor: 15 | if cuda: 16 | l = torch.cuda.FloatTensor(1).fill_(l) 17 | else: 18 | l = torch.FloatTensor(1).fill_(l) 19 | if type(u) is not torch.Tensor: 20 | if cuda: 21 | u = torch.cuda.FloatTensor(1).fill_(u) 22 | else: 23 | u = torch.FloatTensor(1).fill_(u) 24 | return torch.max(torch.min(X, u), l) 25 | 26 | 27 | def configure_logger(model_name, debug): 28 | logging.basicConfig(format='%(message)s') # , level=logging.DEBUG) 29 | logger = logging.getLogger() 30 | logger.handlers = [] # remove the default logger 31 | 32 | # add a new logger for stdout 33 | formatter = logging.Formatter('%(message)s') 34 | ch = logging.StreamHandler() 35 | ch.setFormatter(formatter) 36 | ch.setLevel(logging.DEBUG) 37 | logger.addHandler(ch) 38 | 39 | if not debug: 40 | # add a new logger to a log file 41 | logger.addHandler(logging.FileHandler('logs/{}.log'.format(model_name))) 42 | 43 | return logger 44 | 45 | 46 | def get_random_delta(shape, eps, at_norm, requires_grad=True): 47 | delta = torch.zeros(shape).cuda() 48 | if at_norm == 'l2': # uniform from the hypercube 49 | delta.normal_() 50 | delta /= (delta**2).sum([1, 2, 3], keepdim=True)**0.5 51 | elif at_norm == 'linf': # uniform on the sphere 52 | delta.uniform_(-eps, eps) 53 | else: 54 | raise ValueError('wrong at_norm') 55 | delta.requires_grad = requires_grad 56 | return delta 57 | 58 | 59 | def project_lp(img, at_norm, eps): 60 | if at_norm == 'l2': # uniform on the sphere 61 | l2_norms = (img ** 2).sum([1, 2, 3], keepdim=True) ** 0.5 62 | img_proj = img * torch.min(eps/l2_norms, torch.ones_like(l2_norms)) # if eps>l2_norms => multiply by 1 63 | elif at_norm == 'linf': # uniform from the hypercube 64 | img_proj = clamp(img, -eps, eps) 65 | else: 66 | raise ValueError('wrong at_norm') 67 | return img_proj 68 | 69 | 70 | def update_metrics(metrics_dict, metrics_values, metrics_names): 71 | assert len(metrics_values) == len(metrics_names) 72 | for metric_value, metric_name in zip(metrics_values, metrics_names): 73 | metrics_dict[metric_name].append(metric_value) 74 | 
return metrics_dict 75 | 76 | 77 | @contextmanager 78 | def nullcontext(enter_result=None): 79 | yield enter_result 80 | 81 | 82 | def get_flat_grad(model): 83 | return torch.cat([p.grad.flatten() for p in model.parameters() if p.grad is not None]) 84 | 85 | 86 | def zero_grad(model): 87 | for p in model.parameters(): 88 | if p.grad is not None: 89 | p.grad.zero_() 90 | 91 | 92 | def eval_f_val_grad(model, f): 93 | zero_grad(model) 94 | 95 | obj = f(model) # grads are accumulated 96 | obj.backward() 97 | grad_norm = get_flat_grad(model).norm() 98 | 99 | zero_grad(model) 100 | return obj.detach(), grad_norm.detach() 101 | 102 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/data.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_datasets as tfds 3 | from functools import partial 4 | import numpy as np 5 | import jax.numpy as jnp 6 | 7 | 8 | IMAGE_SIZE = 384 9 | 10 | 11 | def coco_image_processor(feature_dict, image_size): 12 | image = feature_dict['image'] 13 | image = tf.cast(image, tf.float32) 14 | image = tf.image.resize(image, image_size) 15 | image /= 255.0 16 | feature_dict['image'] = image 17 | return feature_dict 18 | 19 | 20 | COCO_BUILDER = None 21 | def get_coco_dataset_iter(split='train', image_size=(IMAGE_SIZE, IMAGE_SIZE), shuffle=True): 22 | global COCO_BUILDER 23 | if COCO_BUILDER is None: 24 | COCO_BUILDER = tfds.builder('coco_captions') 25 | COCO_BUILDER.download_and_prepare() 26 | 27 | ds = COCO_BUILDER.as_dataset(split=split).map( 28 | partial(coco_image_processor, image_size=image_size) 29 | ) 30 | if shuffle: 31 | ds = ds.shuffle(10000) 32 | 33 | return ds.as_numpy_iterator() 34 | 35 | 36 | def get_batch_iter_images(ds, total_batch_size): 37 | while True: 38 | image_inputs = [] 39 | for _, d in zip(range(total_batch_size), ds): 40 | image_inputs.append(d['image']) 41 | if len(image_inputs) != total_batch_size: 42 | break 43 | else: 44 | batch = {} 45 | batch['image'] = np.stack(image_inputs, axis=0) 46 | batch = {k: jnp.array(v) for k, v in batch.items()} 47 | yield batch 48 | 49 | 50 | def get_batch_iter(ds, tokenizer, max_text_length, total_batch_size, rand): 51 | while True: 52 | text_inputs = [] 53 | image_inputs = [] 54 | for _, d in zip(range(total_batch_size), ds): 55 | texts = d['captions']['text'] 56 | if rand: 57 | text = texts[rand.randint(len(texts))] 58 | else: 59 | text = texts[0] 60 | text_inputs.append(str(text)) 61 | # image_inputs.append(np.swapaxes(d['image'], 0, 2)) 62 | image_inputs.append(d['image']) 63 | if len(text_inputs) != total_batch_size: 64 | break 65 | else: 66 | batch = tokenizer( 67 | text_inputs, 68 | padding='max_length', 69 | max_length=max_text_length, 70 | truncation=True, 71 | return_tensors='np', 72 | ) 73 | batch['image'] = np.stack(image_inputs, axis=0) 74 | batch = {k: jnp.array(v) for k, v in batch.items()} 75 | yield batch 76 | 77 | 78 | def create_split(dataset_builder, batch_size, dtype=tf.float32, 79 | image_size=IMAGE_SIZE, cache=False): 80 | """Creates a split from the dataset using TensorFlow Datasets. 81 | Args: 82 | dataset_builder: TFDS dataset builder for ImageNet. 83 | batch_size: the batch size returned by the data pipeline. 84 | train: Whether to load the train or evaluation split. 85 | dtype: data type of the image. 86 | image_size: The target size of the images. 87 | cache: Whether to cache the dataset. 88 | Returns: 89 | A `tf.data.Dataset`. 
90 | """ 91 | options = tf.data.Options() 92 | options.experimental_threading.private_threadpool_size = 48 93 | ds = ds.with_options(options) 94 | 95 | if cache: 96 | ds = ds.cache() 97 | 98 | ds = ds.repeat() 99 | ds = ds.shuffle(16 * batch_size, seed=0) 100 | 101 | ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) 102 | ds = ds.batch(batch_size, drop_remainder=True) 103 | 104 | if not train: 105 | ds = ds.repeat() 106 | 107 | ds = ds.prefetch(10) 108 | 109 | return ds 110 | 111 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Flax implementation of ResNet V1.""" 16 | 17 | # See issue #620. 18 | # pytype: disable=wrong-arg-count 19 | 20 | import logging 21 | import os 22 | logging.getLogger('tensorflow').disabled = True 23 | logging.disable(logging.WARNING) # key line to disable the annoying warnings 24 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = 'true' 25 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 26 | # os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.4' 27 | import sys 28 | sys.path.insert(0, 'vision_transformer') # so that the local version of vit_jax is directly available 29 | 30 | from functools import partial 31 | from typing import Any, Callable, Sequence, Tuple 32 | import jax.numpy as jnp 33 | import jax 34 | import argparse 35 | import tensorflow as tf 36 | from flax import linen as nn 37 | from flax.training import checkpoints as flax_checkpoints 38 | # from flax.training.common_utils import shard, shard_prng_key 39 | from transformers import FlaxBertModel, AutoConfig, AutoTokenizer 40 | from models import ResNet18, get_image_model 41 | from data import IMAGE_SIZE, get_coco_dataset_iter 42 | from utils import calc_test_metrics_posthoc, compute_weight_matrix_rank 43 | 44 | 45 | tf.get_logger().setLevel('INFO') 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser(description='Description of your program') 50 | parser.add_argument('--bs', type=int, default=128, help='Batch size (128 = 10 GB is used)') 51 | parser.add_argument('--n_eval_batches', type=int, default=10, help='N batches') 52 | parser.add_argument('--projection_dim', type=int, default=768, help='Output feature dimension') 53 | parser.add_argument('--bottleneck_dim', type=int, default=-1, help='Linear bottleneck dimension') 54 | parser.add_argument('--model_path', type=str, default='', help='Model path') 55 | parser.add_argument('--split', type=str, default='test', help='Model path') 56 | parser.add_argument('--return_layer', type=int, default=-1, help='Which layer to return for GeLU and MLP output') 57 | args = parser.parse_args() 58 | 59 | 60 | max_text_length = 32 61 | text_model_checkpoint = 'bert-base-cased' 62 | 63 | # import ipdb;ipdb.set_trace() 64 | 
text_config = AutoConfig.from_pretrained(text_model_checkpoint) 65 | text_config.attention_probs_dropout_prob = 0 66 | text_config.hidden_dropout_prob = 0 67 | text_config._name_or_path = None 68 | text_model = FlaxBertModel(config=text_config) 69 | tokenizer = AutoTokenizer.from_pretrained(text_model_checkpoint) 70 | 71 | if args.bottleneck_dim > 0: 72 | text_proj = nn.Sequential([nn.Dense(args.bottleneck_dim), nn.Dense(args.projection_dim)]) 73 | image_proj = nn.Sequential([nn.Dense(args.bottleneck_dim), nn.Dense(args.projection_dim)]) 74 | else: 75 | text_proj = nn.Dense(args.projection_dim) 76 | image_proj = nn.Dense(args.projection_dim) 77 | 78 | rng = jax.random.PRNGKey(0) 79 | image_model_rng, rng = jax.random.split(rng) 80 | text_proj_rng, rng = jax.random.split(rng) 81 | 82 | image_model, image_params = get_image_model(random_init_image=False) 83 | image_state = None 84 | 85 | image_proj_params = image_proj.init(image_model_rng, jnp.ones((image_model.hidden_size, ))) 86 | text_proj_params = text_proj.init(text_proj_rng, jnp.ones((text_config.hidden_size,))) 87 | 88 | models = [image_model, text_model, image_proj, text_proj] 89 | # params = [image_params, text_model.params, image_proj_params, text_proj_params] 90 | 91 | params_ckpt = flax_checkpoints.restore_checkpoint(args.model_path, target=None)['0'] 92 | params = [params_ckpt['0'], params_ckpt['1'], params_ckpt['2'], params_ckpt['3']] 93 | 94 | # TODO: extract features from the BERT encoder as well (in a similar way; need a hacked HF repo here?) 95 | metrics = {} 96 | metrics = calc_test_metrics_posthoc( 97 | models, 98 | params, 99 | image_state, 100 | [rng, rng], 101 | tokenizer, 102 | max_text_length, 103 | args.bs, 104 | args.n_eval_batches, 105 | get_coco_dataset_iter(split=args.split, shuffle=False), 106 | args.split, 107 | args.return_layer 108 | ) 109 | metrics['weight_matrix_ranks'] = compute_weight_matrix_rank(params, pc_threshold=0.99) 110 | metrics['rho'] = args.model_path.split('rho=')[1].split('_')[0] 111 | metrics['bottleneck_dim'] = args.bottleneck_dim 112 | metrics['split'] = args.split # train or test 113 | metrics['return_layer'] = args.return_layer 114 | 115 | print('{},'.format(metrics)) 116 | 117 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/eval_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Flax implementation of ResNet V1.""" 16 | 17 | # See issue #620. 
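# Hedged sketch (assumption, illustration only): `calc_pca_rank`, imported from utils
# below, is not shown in this section. Judging from its use further down
# (calc_pca_rank(features, [0.99])[0]) and from the `feature_ranks` formula in
# classification_tasks/eval_feature_rank.py, it is assumed to count the number of
# principal components needed to explain a given fraction of the feature variance.
# The helper below is a simplified, single-threshold version of that assumed behaviour.
import numpy as np  # numpy is imported again below by the original script


def pca_rank_sketch(features, threshold=0.99):
    """Number of PCA directions explaining `threshold` of the variance (assumed helper)."""
    centered = features - features.mean(axis=0, keepdims=True)   # center before PCA
    svals = np.linalg.svd(centered, compute_uv=False)             # singular values
    explained = np.cumsum(svals ** 2) / np.sum(svals ** 2)        # cumulative explained variance
    return int(np.sum(explained <= threshold)) + 1                # counting rule as in eval_feature_rank.py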
18 | # pytype: disable=wrong-arg-count 19 | 20 | import logging 21 | import os 22 | logging.getLogger('tensorflow').disabled = True 23 | logging.disable(logging.WARNING) # key line to disable the annoying warnings 24 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = 'true' 25 | os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 26 | import sys 27 | sys.path.insert(0, 'vision_transformer') # so that the local version of vit_jax is directly available 28 | 29 | from functools import partial 30 | from typing import Any, Callable, Sequence, Tuple 31 | import jax.numpy as jnp 32 | import jax 33 | import argparse 34 | import tensorflow as tf 35 | import numpy as np 36 | from flax import linen as nn 37 | # from flax.training.common_utils import shard, shard_prng_key 38 | from data import IMAGE_SIZE, get_coco_dataset_iter, get_batch_iter_images 39 | from utils import calc_pca_rank, gelu 40 | from vit_jax import checkpoint as vit_checkpoint 41 | from vit_jax import models as vit_models 42 | from vit_jax.configs import models as vit_models_config 43 | 44 | 45 | tf.get_logger().setLevel('INFO') 46 | 47 | 48 | if __name__ == '__main__': 49 | parser = argparse.ArgumentParser(description='Description of your program') 50 | parser.add_argument('--bs', type=int, default=128, help='Batch size (128 = 10 GB is used)') 51 | parser.add_argument('--n_eval_batches', type=int, default=10, help='N batches') 52 | parser.add_argument('--model_path', type=str, default='', help='Model path') 53 | parser.add_argument('--split', type=str, default='test', help='Split: train or test') 54 | parser.add_argument('--return_layer', type=int, default=-1, help='Which layer to return') 55 | parser.add_argument('--avg_tokens', action='store_true') 56 | args = parser.parse_args() 57 | 58 | config_name = args.model_path[:-4].replace('-224', '').split('/')[-1] # e.g., ViT-B_16, ViT-L_32, etc 59 | model_config = vit_models_config.MODEL_CONFIGS[config_name] 60 | resolution = 224 if 'sam/' in args.model_path or '224' in args.model_path or 'Mixer' in args.model_path else 384 # seems like 224 (and not 384) since this is what's mentioned in Chen et al. 
(2021) 61 | 62 | if config_name.startswith('Mixer'): 63 | model = vit_models.MlpMixer(num_classes=None, **model_config) 64 | else: 65 | model = vit_models.VisionTransformer(num_classes=None, **model_config) 66 | 67 | params = vit_checkpoint.load(args.model_path) 68 | 69 | metrics = {'n_eval_batches': args.n_eval_batches, 'model_path': args.model_path, 'return_layer': args.return_layer} 70 | pc_threshold = 0.99 71 | pc_threshold_name = str(pc_threshold).replace('0.', '') 72 | 73 | preatt_all, preact_all, mlp_all, res_all = [], [], [], [] 74 | ds = get_imagenet_dataset_iter(split=args.split, image_size=(resolution, resolution), shuffle=False) 75 | for _, batch in zip( 76 | range(args.n_eval_batches), 77 | get_batch_iter_images(ds, args.bs) 78 | ): 79 | preatt, preact, mlp, res = model.apply({'params': params}, 2*(batch['image'] - 0.5), train=False, return_acts=True, return_layer=args.return_layer) 80 | preatt_all.append(np.asarray(preatt[:, 0, :] if not args.avg_tokens else preatt.mean(1))) 81 | preact_all.append(np.asarray(preact[:, 0, :] if not args.avg_tokens else preact.mean(1))) 82 | mlp_all.append(np.asarray(mlp[:, 0, :] if not args.avg_tokens else mlp.mean(1))) 83 | res_all.append(np.asarray(res[:, 0, :] if not args.avg_tokens else res.mean(1))) 84 | 85 | preatt_all = np.concatenate(preatt_all, axis=0) 86 | preact_all = np.concatenate(preact_all, axis=0) 87 | mlp_all = np.concatenate(mlp_all, axis=0) 88 | res_all = np.concatenate(res_all, axis=0) 89 | acts_all = gelu(preact_all) 90 | premlp_all = res_all - mlp_all 91 | att_all = premlp_all - preatt_all 92 | 93 | metrics[f'image_rank_{pc_threshold_name}p_preatt'] = calc_pca_rank(preatt_all, [pc_threshold])[0] 94 | metrics[f'image_rank_{pc_threshold_name}p_preact'] = calc_pca_rank(preact_all, [pc_threshold])[0] 95 | metrics[f'image_rank_{pc_threshold_name}p_mlp'] = calc_pca_rank(mlp_all, [pc_threshold])[0] 96 | metrics[f'image_rank_{pc_threshold_name}p_res'] = calc_pca_rank(res_all, [pc_threshold])[0] 97 | metrics[f'image_rank_{pc_threshold_name}p_acts'] = calc_pca_rank(acts_all, [pc_threshold])[0] 98 | metrics[f'image_rank_{pc_threshold_name}p_premlp'] = calc_pca_rank(premlp_all, [pc_threshold])[0] 99 | metrics[f'image_rank_{pc_threshold_name}p_att'] = calc_pca_rank(att_all, [pc_threshold])[0] 100 | 101 | print('{},'.format(metrics)) 102 | 103 | 104 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Flax implementation of ResNet V1.""" 16 | 17 | # See issue #620. 
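# Hedged sketch (assumption, illustration only): `sam_loss_grad_fn`, imported from utils
# below, is not shown in this section. Based on the two-step SAM procedure documented in
# classification_tasks/sam.py (first_step / second_step), it is assumed to return the
# gradient evaluated at the perturbed point w + rho * g / ||g||. A simplified
# loss_fn(params, batch) signature is assumed here; the actual loss_fn used below also
# takes image_state, dropout_rngs, train and T, and returns an auxiliary state.
import jax
import jax.numpy as jnp


def sam_grad_sketch(loss_fn, params, batch, rho):
    """Gradient of loss_fn at the SAM-perturbed parameters (illustrative only)."""
    g = jax.grad(loss_fn)(params, batch)                                         # gradient at w
    g_norm = jnp.sqrt(sum(jnp.vdot(x, x) for x in jax.tree_util.tree_leaves(g))) # ||g||
    perturb = jax.tree_util.tree_map(lambda x: rho * x / (g_norm + 1e-12), g)    # e(w)
    params_adv = jax.tree_util.tree_map(lambda p, e: p + e, params, perturb)     # w + e(w)
    return jax.grad(loss_fn)(params_adv, batch)                                  # gradient at w + e(w)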
18 | # pytype: disable=wrong-arg-count 19 | 20 | import sys 21 | sys.path.insert(0, 'vision_transformer') 22 | 23 | import os 24 | os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = 'true' 25 | # os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' # uncomment for dynamic memory allocation 26 | 27 | from functools import partial 28 | from typing import Any, Callable, Sequence, Tuple 29 | import jax.numpy as jnp 30 | import jax 31 | import argparse 32 | import collections 33 | import torch 34 | import tensorflow as tf 35 | import optax 36 | import numpy as np 37 | import tqdm 38 | import flax 39 | from flax import linen as nn 40 | from flax.training import checkpoints as flax_checkpoints 41 | # from flax.training.common_utils import shard, shard_prng_key 42 | from transformers import FlaxBertModel, AutoConfig, AutoTokenizer 43 | from vit_jax import input_pipeline 44 | from models import ResNet18, get_image_model 45 | from data import IMAGE_SIZE, get_coco_dataset_iter, get_batch_iter 46 | from utils import calc_test_metrics, get_loss_fn, get_gradnorm_reg, sam_loss_grad_fn 47 | 48 | 49 | def run( 50 | learning_rate=0.001, 51 | sam_rho=0.0, 52 | grad_norm_rho=0.0, 53 | seed=0, 54 | model_checkpoint='bert-base-cased', 55 | num_train_epochs=1, 56 | optimizer_name='adam', 57 | do_gradnorm_squared=False, 58 | total_batch_size=64, 59 | max_text_length=32, 60 | n_test_batches=1, 61 | temperature=1.0, 62 | projection_dim=768, 63 | bottleneck_dim=-1, 64 | export_dir='.', 65 | run_id='', 66 | random_init_image=False, 67 | random_init_text=False, 68 | ): 69 | tb_dir = os.path.join(export_dir, 'tb', run_id) 70 | if not tf.io.gfile.isdir(tb_dir): 71 | tf.io.gfile.makedirs(tb_dir) 72 | tb_writer = tf.summary.create_file_writer(tb_dir) 73 | 74 | text_config = AutoConfig.from_pretrained(model_checkpoint) 75 | text_config.attention_probs_dropout_prob = 0 76 | text_config.hidden_dropout_prob = 0 77 | if random_init_text: 78 | text_model = FlaxBertModel(config=text_config) 79 | else: 80 | text_model = FlaxBertModel.from_pretrained(model_checkpoint, config=text_config) 81 | tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) 82 | 83 | if bottleneck_dim > 0: 84 | text_proj = nn.Sequential([nn.Dense(bottleneck_dim), nn.Dense(projection_dim)]) 85 | image_proj = nn.Sequential([nn.Dense(bottleneck_dim), nn.Dense(projection_dim)]) 86 | else: 87 | text_proj = nn.Dense(projection_dim) 88 | image_proj = nn.Dense(projection_dim) 89 | 90 | rng = jax.random.PRNGKey(seed) 91 | image_model_rng, rng = jax.random.split(rng) 92 | text_proj_rng, rng = jax.random.split(rng) 93 | 94 | image_model, image_params = get_image_model(random_init_image) 95 | image_state = None 96 | 97 | # dropout_rngs = jax.random.split(rng, jax.local_device_count()) 98 | text_proj_params = text_proj.init(text_proj_rng, jnp.ones((text_config.hidden_size,))) 99 | image_proj_params = image_proj.init(image_model_rng, jnp.ones((image_model.hidden_size, ))) 100 | 101 | models = [image_model, text_model, image_proj, text_proj] 102 | params = [image_params, text_model.params, image_proj_params, text_proj_params] 103 | 104 | num_train_samples = 82783 105 | n_iters_per_epoch = num_train_samples // total_batch_size 106 | n_total_iters = num_train_epochs * n_iters_per_epoch 107 | 108 | lr_scheduler = optax.cosine_decay_schedule(learning_rate, decay_steps=n_total_iters, alpha=0.0) 109 | if optimizer_name == 'adam': 110 | tx = optax.adam(learning_rate=lr_scheduler) 111 | elif optimizer_name == 'sgd': 112 | tx = optax.sgd(learning_rate=lr_scheduler) 113 | else: 114 | 
raise ValueError('Unknown optimizer_name: {}'.format(optimizer_name)) 115 | 116 | opt_state = tx.init(params) 117 | loss_fn = get_loss_fn(models) 118 | 119 | assert not (sam_rho > 0 and grad_norm_rho > 0) 120 | if sam_rho > 0.0: 121 | loss_grad_fn = sam_loss_grad_fn(loss_fn, sam_rho=sam_rho, has_aux=True) 122 | elif grad_norm_rho > 0.0: 123 | loss_grad_fn = jax.value_and_grad( 124 | get_gradnorm_reg( 125 | loss_fn, 126 | rho=grad_norm_rho, 127 | do_gradnorm_squared=do_gradnorm_squared, 128 | has_aux=True, 129 | ), 130 | has_aux=True, 131 | ) 132 | else: 133 | loss_grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 134 | 135 | def train_step(params, batch, image_state, opt_state, dropout_rng): 136 | left_dropout_rng, right_dropout_rng, new_dropout_rng = jax.random.split( 137 | dropout_rng, 3 138 | ) 139 | 140 | (loss_val, image_state), grads = loss_grad_fn( 141 | params, 142 | batch, 143 | image_state, 144 | dropout_rngs=[left_dropout_rng, right_dropout_rng], 145 | train=True, 146 | T=temperature, 147 | ) 148 | # grads = jax.lax.pmean(grads, "batch") 149 | 150 | updates, opt_state = tx.update(grads, opt_state) 151 | params = optax.apply_updates(params, updates) 152 | 153 | # metrics = jax.lax.pmean({"loss": loss_val}, axis_name="batch") 154 | metrics = {'loss': loss_val} 155 | return params, new_dropout_rng, image_state, opt_state, metrics # return the updated opt_state so that Adam moments and the LR schedule advance across steps 156 | 157 | # per_device_batch_size = 4 158 | # total_batch_size = per_device_batch_size * jax.local_device_count() 159 | 160 | rand = np.random.RandomState(seed) 161 | 162 | # training loop 163 | metrics = collections.defaultdict(list) 164 | step = 0 165 | 166 | for _, epoch in enumerate( 167 | tqdm.tqdm( 168 | range(1, num_train_epochs + 1), 169 | desc=f'Epoch ...', 170 | position=0, 171 | leave=True, 172 | ) 173 | ): 174 | # train 175 | with tqdm.tqdm( 176 | total=n_iters_per_epoch, 177 | desc='Training...', 178 | leave=False, 179 | ) as progress_bar_train: 180 | ds = get_coco_dataset_iter(split='train', shuffle=True) 181 | for batch in get_batch_iter( 182 | ds=ds, 183 | tokenizer=tokenizer, 184 | max_text_length=max_text_length, 185 | total_batch_size=total_batch_size, 186 | rand=rand, 187 | ): # TODO: get_batch_iter takes a few sec!
params, rng, image_state, opt_state, train_metrics = train_step( 189 | params, batch, image_state, opt_state, rng 190 | ) # TODO: first run takes a long time and only 1 CPU core is occupied at first 191 | # params, dropout_rngs, image_state = parallel_train_step(params, shard(batch), shard(image_state), opt_state, dropout_rngs) 192 | # train_loss_val = round(flax.jax_utils.unreplicate(train_metrics)['loss'].item(), 3) 193 | 194 | loss_val = train_metrics['loss'].item() 195 | metrics['train_loss'].append(loss_val) 196 | if tb_writer: 197 | with tb_writer.as_default(): 198 | tf.summary.scalar('train_loss', loss_val, step=step) 199 | 200 | if step % 10 == 0: 201 | print('[step={}] train loss {:.3f}'.format(step, np.mean(metrics['train_loss'][-5:]))) 202 | 203 | if step % 100 == 0: 204 | m_test = calc_test_metrics( 205 | models, 206 | params, 207 | image_state, 208 | [rng, rng], 209 | loss_fn, 210 | temperature, 211 | tokenizer, 212 | max_text_length, 213 | total_batch_size, 214 | n_test_batches, 215 | get_coco_dataset_iter(split='test', shuffle=False), 216 | 'test', 217 | ) 218 | print('[step={}] {}'.format(step, m_test)) 219 | m = calc_test_metrics( 220 | models, 221 | params, 222 | image_state, 223 | [rng, rng], 224 | loss_fn, 225 | temperature, 226 | tokenizer, 227 | max_text_length, 228 | total_batch_size, 229 | n_test_batches, 230 | get_coco_dataset_iter(split='train', shuffle=False), 231 | 'train', 232 | ) 233 | print('[step={}] {}'.format(step, m)) 234 | m.update(m_test) 235 | for k, v in m.items(): 236 | metrics[k].append(v) 237 | if tb_writer: 238 | with tb_writer.as_default(): 239 | for k, v in m.items(): 240 | tf.summary.scalar(k, v, step=step) 241 | 242 | progress_bar_train.update(1) 243 | step += 1 244 | 245 | try: 246 | checkpoint_path = flax_checkpoints.save_checkpoint( 247 | f'{export_dir}/models/{run_id}', (params, opt_state, step), step, overwrite=True, keep=5) 248 | # params, opt_state, initial_step = flax_checkpoints.restore_checkpoint( 249 | # workdir, (params, opt_state, initial_step)) 250 | print('Saved the model at {}'.format(checkpoint_path)) 251 | except Exception as e: 252 | print('Failed to save the model: {}'.format(e)) 253 | 254 | for k, v in metrics.items(): 255 | metrics[k] = np.array(v) 256 | return metrics 257 | 258 | 259 | parser = argparse.ArgumentParser(description='Contrastive text-image (CLIP-style) training with optional SAM') 260 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate') 261 | parser.add_argument('--rho', type=float, default=0.0, help='Radius of SAM') 262 | parser.add_argument('--epochs', type=int, default=10, help='Epochs') 263 | parser.add_argument('--bs', type=int, default=32, help='Batch size') 264 | parser.add_argument('--projection_dim', type=int, default=768, help='Output feature dimension') 265 | parser.add_argument('--bottleneck_dim', type=int, default=-1, help='Bottleneck dimension of the projection head (-1 disables the bottleneck)') 266 | parser.add_argument('--random_init_image', action='store_true', help='Use random init instead of a pretrained image model') 267 | parser.add_argument('--random_init_text', action='store_true', help='Use random init instead of a pretrained text model') 268 | parser.add_argument('--run_name', type=str, default='', help='Append this string to run_id') 269 | args = parser.parse_args() 270 | 271 | # used for Tensorboard and for saving the models 272 | run_id = 'lr={}_rho={}_random_init_image={}_random_init_text={}_bottleneck_dim={}'.format( 273 | args.lr, args.rho, args.random_init_image, args.random_init_text, args.bottleneck_dim) 274 | if args.run_name != '': 275 | run_id += '_' +
args.run_name 276 | 277 | run( 278 | learning_rate=args.lr, 279 | sam_rho=args.rho, 280 | grad_norm_rho=0., 281 | seed=0, 282 | model_checkpoint="bert-base-cased", 283 | num_train_epochs=args.epochs, 284 | optimizer_name='adam', 285 | do_gradnorm_squared=True, 286 | total_batch_size=args.bs, 287 | max_text_length=32, 288 | n_test_batches=10, 289 | temperature=0.05, 290 | projection_dim=args.projection_dim, 291 | bottleneck_dim=args.bottleneck_dim, 292 | export_dir='/mnt/main-disk', 293 | run_id=run_id, 294 | random_init_image=args.random_init_image, 295 | random_init_text=args.random_init_text, 296 | ) 297 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Flax implementation of ResNet V1.""" 16 | 17 | # See issue #620. 18 | # pytype: disable=wrong-arg-count 19 | 20 | from functools import partial 21 | from typing import Any, Callable, Sequence, Tuple 22 | import jax 23 | import jax.numpy as jnp 24 | from flax import linen as nn 25 | 26 | from vit_jax import checkpoint as vit_checkpoint 27 | from vit_jax import models as vit_models 28 | from vit_jax.configs import models as vit_models_config 29 | 30 | ModuleDef = Any 31 | 32 | 33 | def get_image_model(random_init_image): 34 | filename = 'R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384' 35 | model_config = vit_models_config.AUGREG_CONFIGS[filename.split('-')[0]] 36 | model_config.transformer.dropout_rate = 0.0 37 | path = f'gs://vit_models/augreg/{filename}.npz' 38 | resolution = int(filename.split('_')[-1]) 39 | model = vit_models.VisionTransformer(num_classes=None, **model_config) 40 | if random_init_image: 41 | params = model.init(jax.random.PRNGKey(0), jnp.ones((1, resolution, resolution, 3)), train=False)['params'] 42 | else: 43 | params = vit_checkpoint.load(path) 44 | return model, params 45 | 46 | 47 | class ResNetBlock(nn.Module): 48 | """ResNet block.""" 49 | filters: int 50 | conv: ModuleDef 51 | norm: ModuleDef 52 | act: Callable 53 | strides: Tuple[int, int] = (1, 1) 54 | 55 | @nn.compact 56 | def __call__(self, x,): 57 | residual = x 58 | y = self.conv(self.filters, (3, 3), self.strides)(x) 59 | y = self.norm()(y) 60 | y = self.act(y) 61 | y = self.conv(self.filters, (3, 3))(y) 62 | y = self.norm(scale_init=nn.initializers.zeros)(y) 63 | 64 | if residual.shape != y.shape: 65 | residual = self.conv(self.filters, (1, 1), 66 | self.strides, name='conv_proj')(residual) 67 | residual = self.norm(name='norm_proj')(residual) 68 | 69 | return self.act(residual + y) 70 | 71 | 72 | class BottleneckResNetBlock(nn.Module): 73 | """Bottleneck ResNet block.""" 74 | filters: int 75 | conv: ModuleDef 76 | norm: ModuleDef 77 | act: Callable 78 | strides: Tuple[int, int] = (1, 1) 79 | 80 | @nn.compact 81 | 
def __call__(self, x): 82 | residual = x 83 | y = self.conv(self.filters, (1, 1))(x) 84 | y = self.norm()(y) 85 | y = self.act(y) 86 | y = self.conv(self.filters, (3, 3), self.strides)(y) 87 | y = self.norm()(y) 88 | y = self.act(y) 89 | y = self.conv(self.filters * 4, (1, 1))(y) 90 | y = self.norm(scale_init=nn.initializers.zeros)(y) 91 | 92 | if residual.shape != y.shape: 93 | residual = self.conv(self.filters * 4, (1, 1), 94 | self.strides, name='conv_proj')(residual) 95 | residual = self.norm(name='norm_proj')(residual) 96 | 97 | return self.act(residual + y) 98 | 99 | 100 | class ResNet(nn.Module): 101 | """ResNetV1.""" 102 | stage_sizes: Sequence[int] 103 | block_cls: ModuleDef 104 | num_classes: int 105 | num_filters: int = 64 106 | dtype: Any = jnp.float32 107 | act: Callable = nn.relu 108 | conv: ModuleDef = nn.Conv 109 | 110 | @nn.compact 111 | def __call__(self, x, train: bool = True): 112 | conv = partial(self.conv, use_bias=False, dtype=self.dtype) 113 | norm = partial(nn.BatchNorm, 114 | use_running_average=not train, 115 | momentum=0.9, 116 | epsilon=1e-5, 117 | dtype=self.dtype) 118 | 119 | x = conv(self.num_filters, (7, 7), (2, 2), 120 | padding=[(3, 3), (3, 3)], 121 | name='conv_init')(x) 122 | x = norm(name='bn_init')(x) 123 | x = nn.relu(x) 124 | x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') 125 | for i, block_size in enumerate(self.stage_sizes): 126 | for j in range(block_size): 127 | strides = (2, 2) if i > 0 and j == 0 else (1, 1) 128 | x = self.block_cls(self.num_filters * 2 ** i, 129 | strides=strides, 130 | conv=conv, 131 | norm=norm, 132 | act=self.act)(x) 133 | x = jnp.mean(x, axis=(1, 2)) 134 | x = nn.Dense(self.num_classes, dtype=self.dtype)(x) 135 | x = jnp.asarray(x, self.dtype) 136 | return x 137 | 138 | 139 | ResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2], 140 | block_cls=ResNetBlock) 141 | ResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3], 142 | block_cls=ResNetBlock) 143 | ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3], 144 | block_cls=BottleneckResNetBlock) 145 | ResNet101 = partial(ResNet, stage_sizes=[3, 4, 23, 3], 146 | block_cls=BottleneckResNetBlock) 147 | ResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3], 148 | block_cls=BottleneckResNetBlock) 149 | ResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3], 150 | block_cls=BottleneckResNetBlock) 151 | 152 | 153 | ResNet18Local = partial(ResNet, stage_sizes=[2, 2, 2, 2], 154 | block_cls=ResNetBlock, conv=nn.ConvLocal) 155 | 156 | 157 | # Used for testing only. 158 | _ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock) 159 | _ResNet1Local = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock, 160 | conv=nn.ConvLocal) -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint. 
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Build 5 | 6 | on: 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | matrix: 16 | python-version: [3.8] 17 | steps: 18 | - name: Cancel previous 19 | uses: styfle/cancel-workflow-action@0.8.0 20 | with: 21 | access_token: ${{ github.token }} 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v1 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | pip install . 30 | pip install .[test] 31 | - name: Run pytest 32 | run: | 33 | pytest vit_jax 34 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/.gitignore: -------------------------------------------------------------------------------- 1 | *.npz 2 | __pycache__/ 3 | /.vscode 4 | /.env 5 | vit_jax.egg-info -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style: yapf -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [2020] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/mixer_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-epfl/sam-low-rank-features/0e92a35b7bba64adbae76e56694282fe047d71bb/contrastive_text_image_learning/vision_transformer/mixer_figure.png -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/model_cards/lit.md: -------------------------------------------------------------------------------- 1 | # Model Card: LiT (Locked image Tuning) 2 | 3 | Last updated: 2022-06-19 4 | 5 | Version: 1.0 6 | 7 | - This doc: https://github.com/google-research/vision_transformer/blob/main/model_cards/lit.md 8 | - Model Page: https://github.com/google-research/vision_transformer#lit-models 9 | - Other Links: 10 | [LiT Blogpost](https://ai.googleblog.com/2022/04/locked-image-tuning-adding-language.html), 11 | [LiT Paper], 12 | [LiT Demo](https://google-research.github.io/vision_transformer/lit/) 13 | 14 | A text/image input model that can be used to embed text/image individually, 15 | and compute similarities between embeddings of text/image pairs. This enables 16 | use cases like zero shot classification, or image/text retrieval. 17 | 18 | Note that this model card refers to the models that have been released on 19 | Github specifically (B16B_2, L16L). The [LiT Paper] also evaluates models that 20 | have not been released and use different datasets for training. The Colab 21 | [`lit.ipynb`] lists some more models (L16S, L16Ti) which are similar to L16L, 22 | but with a smaller text tower. 23 | 24 | [LiT Paper]: https://arxiv.org/abs/2111.07991 25 | [`lit.ipynb`]: https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb 26 | 27 | ## Model Summary 28 | 29 | - Architecture: Multimodal model with transformer text encoder and transformer 30 | image encoder. 31 | - Inputs: Images presented in 224x224x3 input, text inputs are tokenized and 32 | cropped to the first 16 tokens. 33 | - Outputs: Image and text embeddings (of size 768 or 1024). 
34 | - Person of contact: Andreas Steiner (Google Brain) 35 | - Model authors: Xioahua Zhai, Xiao Wang, Basil Mustafa, Andreas Steiner, 36 | Daniel Keysers, Alexander Kolesnikov, Lucas Beyer (Google Brain) 37 | 38 | Citation: 39 | 40 | ```bibtex 41 | @article{zhai2022lit, 42 | title={LiT: Zero-Shot Transfer with Locked-image Text Tuning}, 43 | author={Zhai, Xiaohua and Wang, Xiao and Mustafa, Basil and Steiner, Andreas and Keysers, Daniel and Kolesnikov, Alexander and Beyer, Lucas}, 44 | journal={CVPR}, 45 | year={2022} 46 | } 47 | ``` 48 | 49 | ## Model Data 50 | 51 | Training data: 52 | 53 | - [Pre-trained image-tower](http://arxiv.org/abs/2106.10270) (using the 54 | recommended checkpoints from the paper, Section 4.2) 55 | - [ImageNet-21k](https://www.image-net.org/static_files/papers/imagenet_cvpr09.pdf) 56 | - [BERT](http://arxiv.org/abs/1810.04805) pre-trained text tower 57 | - [BookCorpus](https://github.com/jackbandy/bookcorpus-datasheet) 58 | - English wikipedia 59 | - Multi-modal datasets 60 | - [CC12M](https://arxiv.org/abs/2102.08981) 61 | - [YFCC100M](https://arxiv.org/abs/1503.01817) 62 | 63 | Evaluation data (see also section [Evaluation Results](#evaluation-results) 64 | below): 65 | 66 | - Zero-shot classification 67 | - [ImageNet](https://www.image-net.org/static_files/papers/imagenet_cvpr09.pdf) 68 | - [ImageNet v2](http://arxiv.org/abs/1902.10811) 69 | - [CIFAR100](https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) 70 | - [Pets37](https://ieeexplore.ieee.org/abstract/document/6248092) 71 | - [Resisc45](http://arxiv.org/abs/1703.00121) 72 | - Image-text retrieval 73 | - [MS-COCO Captions](https://arxiv.org/abs/1504.00325) 74 | 75 | ## Model Creation & Maintenance 76 | 77 | The model has been initialized from BERT & ViT checkpoints (see details above 78 | "training dataset"), and then contrastively tuned on CC12M and YFCC100M. 79 | 80 | All datasets have been released in previous publications independent from this 81 | model. The datasets and model are not regularly updated. 82 | 83 | The published B16B_2 and L16L models are medium sized and can be used on a normal 84 | computer, or on a single GPU/TPU. 85 | 86 | | Model | B16B_2 | L16L | 87 | | :--- | ---: | ---: | 88 | | Size | 474 MB | 2.4 GB | 89 | | Weights | 196M | 638M | 90 | | Layers | 2x12 | 2x24 | 91 | | Latency (single TPU core) | 1200/sec | 400/sec | 92 | 93 | Software/hardware used for training: 94 | 95 | - JAX 0.3.13, Flax 0.5.0 96 | - 128 TPUv4 cores 97 | 98 | Software/hardware used for deployment: 99 | 100 | - JAX 0.3.13, Flax 0.5.0 101 | - CPU/GPU/TPU 102 | 103 | Compute requirements for training: 104 | 105 | | Model | B16B_2 | L16L | 106 | | :--- | ---: | ---: | 107 | | Number of Chips | 64 | 64 | 108 | | Training Time (days) | 0.3 | 1 | 109 | | Total Computation (FLOPS) | 2.7E+19 | 9E+19 | 110 | | Measured Performance (TFLOPS/s) | 1153 | 1614 | 111 | | Energy Consumption (MWh) | 0.14 | 0.16 | 112 | 113 | Compute requirements for inference: 114 | 115 | | Model | B16B_2 | L16L | 116 | | :--- | ---: | ---: | 117 | | FLOPS/example | approx. 10 | approx. 30 | 118 | 119 | ## Evaluation Results 120 | 121 | Benchmark information: 122 | 123 | - Zero-shot classification (as explained in [CLIP Paper]) 124 | - We chose to evaluate a set of datasets that are commonly used, and provide 125 | insights where the model works very well (such as ImageNet v2 or CIFAR100), 126 | as well as where it is much more limited (such as Resisc45). 
127 | - Image-text retrieval (Appendix section I.3 in [LiT Paper]) 128 | 129 | [CLIP Paper]: https://arxiv.org/abs/2103.00020 130 | 131 | Evaluation results: 132 | 133 | | Model | B16B_2 | L16L | 134 | | :--- | ---: | ---: | 135 | | ImageNet zero-shot | 73.9% | 75.7% | 136 | | ImageNet v2 zero-shot | 65.1% | 66.6% | 137 | | CIFAR100 zero-shot | 79.0% | 80.5% | 138 | | Pets37 zero-shot | 83.3% | 83.3% | 139 | | Resisc45 zero-shot | 25.3% | 25.6% | 140 | | MS-COCO Captions image-to-text retrieval | 51.6% | 48.5% | 141 | | MS-COCO Captions text-to-image retrieval | 31.8% | 31.1% | 142 | 143 | ## Limitations 144 | 145 | Known limitations: 146 | 147 | - Any deployment of this model, both for commercial applications and 148 | non-commercial applications, is currently out of scope. 149 | - Before using the model in a constrained (i.e. not deployed) environment, users 150 | should do in-depth testing for their specific use case (e.g. on a constrained 151 | set of class labels of interest). 152 | - These models have only been trained on English text and will fail for most 153 | non-English inputs. 154 | - These models have not been evaluated with respect to their biases and fairness 155 | aspects. We suspect that biases found in the datasets used for training will 156 | be replicated by model representations, and model predictions should a priori 157 | be considered to replicate these biases, with consequences to various fairness 158 | metrics. 159 | 160 | Ethical considerations & risks: 161 | 162 | - The publication is based on previous work ([CLIP Paper]) that has been shown 163 | (Section 7) to replicate gender biases, perform variably for different groups 164 | of people (by gender, skin color), and cause representational harm in varying 165 | degree for different groups of people (by age, skin color). In the same 166 | section, previous authors have shown that a discriminative image/text model 167 | has the potential to be used in a surveillance context for coarse 168 | classification (although not for fine-grained classification), potentially 169 | lowering the barrier for such problematic use cases. 170 | - These models have not been evaluated for the problems mentioned in previous 171 | work, but until such an evaluation is performed, we expect similar risks. 172 | 173 | ## Model Usage 174 | 175 | Sensitive use: The model has been trained on image datasets containing 176 | pictures of people, both for the pre-training of the image encoder 177 | (ImageNet-21k), and for the contrastive tuning (CC12M and YFCC100M). 178 | 179 | The model is used exclusively in research for now: 180 | 181 | - [Zero-Shot Text-Guided Object Generation with Dream Fields](https://arxiv.org/abs/2112.01455) 182 | - [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) 183 | 184 | ## Model Comparison 185 | 186 | In comparison with "private data" model from [CLIP Paper]: 187 | 188 | - As of 6/10/22, the best published CLIP model is the L/14-336px variant. 189 | - Similar performance (e.g. ImageNet zero-shot classification accuracy: 190 | 76.2% CLIP vs. LiT L16L 75.7%) 191 | - LiT is trained solely on publicly available datasets, while CLIP is trained on 192 | a private undisclosed dataset. 193 | - The LiT L16L model is considerably smaller: CLIP uses 576 tokens vs. LiT L16L 194 | uses 196 tokens – since the runtime/memory complexity of attention scales with 195 | the square of the number of tokens, this corresponds to a factor of 8.63x. 
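The quoted factor follows directly from this quadratic scaling; as a quick illustrative check:

```python
# attention cost scales with the square of the sequence length
print((576 / 196) ** 2)  # ≈ 8.6x
```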
196 | 197 | In comparison with "public data" model from [CLIP Paper]: 198 | 199 | - The only model trained without the private data mentioned in the CLIP paper 200 | (Section D), namely on YFCC100M. 201 | - LiT has much better performance (e.g. ImageNet zero-shot classification 202 | accuracy: 31.3% CLIP vs. LiT L16L 75.7%) 203 | 204 | ## System Dependencies 205 | 206 | Can be used as a stand-alone model (e.g. for zero-shot classification or 207 | retrieval), or as part of a more complex system (basically any system that uses 208 | CLIP as a component can instead use a LiT model). 209 | 210 | Pre-processing instructions can be found on Github: 211 | [vit_jax/preprocess.py](https://github.com/google-research/vision_transformer/blob/main/vit_jax/preprocess.py). 212 | The published models include a pre-processing configuration (specifying 213 | tokenizer vocabulary and image pre-processing). 214 | 215 | The model outputs image and text embeddings and a temperature. If similarities 216 | are to be computed between image and text embeddings (e.g. for computing output 217 | distributions), then the similarities between the embeddings should be computed 218 | with the dot product, and these should then be multiplied by the temperature 219 | before a softmax is applied. 220 | 221 | ## Changelog 222 | 223 | - 2022-08-16: Replaced model B16B with an updated version B16B_2 that was 224 | trained for 60k steps (before: 30k) without linear head on the image side 225 | (before: 768) and has better performance. 226 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """setup.py for vision_transformer repo, vit_jax package.""" 16 | 17 | import os 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | 22 | here = os.path.abspath(os.path.dirname(__file__)) 23 | try: 24 | README = open(os.path.join(here, 'README.md'), encoding='utf-8').read() 25 | except IOError: 26 | README = '' 27 | 28 | install_requires = [ 29 | 'absl-py', 30 | 'clu', 31 | 'einops', 32 | 'flax', 33 | 'flaxformer @ git+https://github.com/google/flaxformer', 34 | 'jax', 35 | 'ml-collections', 36 | 'numpy', 37 | 'packaging', 38 | 'pandas', 39 | 'scipy', 40 | 'tensorflow_datasets', 41 | 'tensorflow_probability', 42 | 'tensorflow', 43 | 'tensorflow_text', 44 | 'tqdm', 45 | ] 46 | 47 | tests_require = [ 48 | 'pytest', 49 | ] 50 | 51 | __version__ = None 52 | 53 | with open(os.path.join(here, 'version.py')) as f: 54 | exec(f.read(), globals()) # pylint: disable=exec-used 55 | 56 | setup( 57 | name='vit_jax', 58 | version=__version__, 59 | description='Original JAX implementation of Vision Transformer models.', 60 | long_description=README, 61 | long_description_content_type='text/markdown', 62 | classifiers=[ 63 | 'Development Status :: 3 - Alpha', 64 | 'Intended Audience :: Developers', 65 | 'Intended Audience :: Science/Research', 66 | 'License :: OSI Approved :: Apache Software License', 67 | 'Programming Language :: Python :: 3.7', 68 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 69 | ], 70 | keywords='', 71 | author='Vision Transformer Authors', 72 | author_email='no-reply@google.com', 73 | url='https://github.com/google-research/vision_transformer', 74 | packages=find_packages(), 75 | zip_safe=False, 76 | install_requires=install_requires, 77 | tests_require=tests_require, 78 | extras_require=dict(test=tests_require), 79 | ) 80 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Current vision_transformer version at head on Github.""" 16 | 17 | __version__ = "0.0.8" 18 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-epfl/sam-low-rank-features/0e92a35b7bba64adbae76e56694282fe047d71bb/contrastive_text_image_learning/vision_transformer/vit_figure.png -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import collections 16 | from collections import abc 17 | import re 18 | 19 | from absl import logging 20 | import flax 21 | from flax.training import checkpoints 22 | import jax.numpy as jnp 23 | import numpy as np 24 | from packaging import version 25 | import pandas as pd 26 | import scipy.ndimage 27 | from tensorflow.io import gfile # pylint: disable=import-error 28 | import tqdm 29 | 30 | 31 | def _flatten_dict(d, parent_key='', sep='/'): 32 | """Flattens a dictionary, keeping empty leaves.""" 33 | items = [] 34 | for k, v in d.items(): 35 | path = parent_key + sep + k if parent_key else k 36 | if isinstance(v, abc.Mapping): 37 | items.extend(_flatten_dict(v, path, sep=sep).items()) 38 | else: 39 | items.append((path, v)) 40 | 41 | # Keeps the empty dict if it was set explicitly. 42 | if parent_key and not d: 43 | items.append((parent_key, {})) 44 | 45 | return dict(items) 46 | 47 | 48 | def inspect_params(*, 49 | params, 50 | expected, 51 | fail_if_extra=True, 52 | fail_if_missing=True): 53 | """Inspects whether the params are consistent with the expected keys.""" 54 | params_flat = _flatten_dict(params) 55 | expected_flat = _flatten_dict(expected) 56 | missing_keys = expected_flat.keys() - params_flat.keys() 57 | extra_keys = params_flat.keys() - expected_flat.keys() 58 | 59 | # Adds back empty dict explicitly, to support layers without weights. 60 | # Context: FLAX ignores empty dict during serialization. 
empty_keys = set() 62 | for k in missing_keys: 63 | if isinstance(expected_flat[k], dict) and not expected_flat[k]: 64 | params[k] = {} 65 | empty_keys.add(k) 66 | missing_keys -= empty_keys 67 | 68 | if empty_keys: 69 | logging.warning('Inspect recovered empty keys:\n%s', empty_keys) 70 | if missing_keys: 71 | logging.info('Inspect missing keys:\n%s', missing_keys) 72 | if extra_keys: 73 | logging.info('Inspect extra keys:\n%s', extra_keys) 74 | 75 | if (missing_keys and fail_if_missing) or (extra_keys and fail_if_extra): 76 | raise ValueError(f'Missing params from checkpoint: {missing_keys}.\n' 77 | f'Extra params in checkpoint: {extra_keys}.\n' 78 | f'Restored params from checkpoint: {params_flat.keys()}.\n' 79 | f'Expected params from code: {expected_flat.keys()}.') 80 | if 'extra' in params: 81 | params = params['opt']['state']['param_states'] 82 | return params 83 | 84 | 85 | def recover_tree(keys, values): 86 | """Recovers a tree as a nested dict from flat names and values. 87 | 88 | This function is useful for analyzing checkpoints without needing access to 89 | the exact source code of the experiment. In particular, it can be used to 90 | extract and reuse various subtrees of the checkpoint, e.g. the subtree of 91 | parameters. 92 | 93 | Args: 94 | keys: a list of keys, where '/' is used as separator between nodes. 95 | values: a list of leaf values. 96 | 97 | Returns: 98 | A nested tree-like dict. 99 | """ 100 | tree = {} 101 | sub_trees = collections.defaultdict(list) 102 | for k, v in zip(keys, values): 103 | if '/' not in k: 104 | tree[k] = v 105 | else: 106 | k_left, k_right = k.split('/', 1) 107 | sub_trees[k_left].append((k_right, v)) 108 | for k, kv_pairs in sub_trees.items(): 109 | k_subtree, v_subtree = zip(*kv_pairs) 110 | tree[k] = recover_tree(k_subtree, v_subtree) 111 | return tree 112 | 113 | 114 | def copy(src, dst, progress=True, block_size=1024 * 1024 * 10): 115 | """Copies a file with progress bar. 116 | 117 | Args: 118 | src: Source file. Path must be readable by `tf.io.gfile`. 119 | dst: Destination file. Path must be writable by `tf.io.gfile`. 120 | progress: Whether to show a progress bar. 121 | block_size: Size of individual blocks to be read/written.
122 | """ 123 | stats = gfile.stat(src) 124 | n = int(np.ceil(stats.length / block_size)) 125 | range_or_trange = tqdm.trange if progress else range 126 | with gfile.GFile(src, 'rb') as fin: 127 | with gfile.GFile(dst, 'wb') as fout: 128 | for _ in range_or_trange(n): 129 | fout.write(fin.read(block_size)) 130 | 131 | 132 | def load(path): 133 | """Loads params from a checkpoint previously stored with `save()`.""" 134 | with gfile.GFile(path, 'rb') as f: 135 | ckpt_dict = np.load(f, allow_pickle=False) 136 | keys, values = zip(*list(ckpt_dict.items())) 137 | params = checkpoints.convert_pre_linen(recover_tree(keys, values)) 138 | if isinstance(params, flax.core.FrozenDict): 139 | params = params.unfreeze() 140 | if version.parse(flax.__version__) >= version.parse('0.3.6'): 141 | params = _fix_groupnorm(params) 142 | return params 143 | 144 | 145 | def _fix_groupnorm(params): 146 | # See https://github.com/google/flax/issues/1721 147 | regex = re.compile(r'gn(\d+|_root|_proj)$') 148 | 149 | def fix_gn(args): 150 | path, array = args 151 | if len(path) > 1 and regex.match( 152 | path[-2]) and path[-1] in ('bias', 'scale'): 153 | array = array.squeeze() 154 | return (path, array) 155 | 156 | return flax.traverse_util.unflatten_dict( 157 | dict(map(fix_gn, 158 | flax.traverse_util.flatten_dict(params).items()))) 159 | 160 | 161 | def load_pretrained(*, pretrained_path, init_params, model_config): 162 | """Loads/converts a pretrained checkpoint for fine tuning. 163 | 164 | Args: 165 | pretrained_path: File pointing to pretrained checkpoint. 166 | init_params: Parameters from model. Will be used for the head of the model 167 | and to verify that the model is compatible with the stored checkpoint. 168 | model_config: Configuration of the model. Will be used to configure the head 169 | and rescale the position embeddings. 170 | 171 | Returns: 172 | Parameters like `init_params`, but loaded with pretrained weights from 173 | `pretrained_path` and adapted accordingly. 174 | """ 175 | 176 | restored_params = inspect_params( 177 | params=load(pretrained_path), 178 | expected=init_params, 179 | fail_if_extra=False, 180 | fail_if_missing=False) 181 | 182 | # The following allows implementing fine-tuning head variants depending on the 183 | # value of `representation_size` in the fine-tuning job: 184 | # - `None` : drop the whole head and attach a nn.Linear. 185 | # - same number as in pre-training means : keep the head but reset the last 186 | # layer (logits) for the new task. 187 | if model_config.get('representation_size') is None: 188 | if 'pre_logits' in restored_params: 189 | logging.info('load_pretrained: drop-head variant') 190 | restored_params['pre_logits'] = {} 191 | restored_params['head']['kernel'] = init_params['head']['kernel'] 192 | restored_params['head']['bias'] = init_params['head']['bias'] 193 | 194 | if 'posembed_input' in restored_params.get('Transformer', {}): 195 | # Rescale the grid of position embeddings. 
Param shape is (1,N,1024) 196 | posemb = restored_params['Transformer']['posembed_input']['pos_embedding'] 197 | posemb_new = init_params['Transformer']['posembed_input']['pos_embedding'] 198 | if posemb.shape != posemb_new.shape: 199 | logging.info('load_pretrained: resized variant: %s to %s', posemb.shape, 200 | posemb_new.shape) 201 | posemb = interpolate_posembed( 202 | posemb, posemb_new.shape[1], model_config.classifier == 'token') 203 | restored_params['Transformer']['posembed_input']['pos_embedding'] = posemb 204 | 205 | if version.parse(flax.__version__) >= version.parse('0.3.6'): 206 | restored_params = _fix_groupnorm(restored_params) 207 | 208 | return flax.core.freeze(restored_params) 209 | 210 | 211 | def interpolate_posembed(posemb, num_tokens: int, has_class_token: bool): 212 | """Interpolate given positional embedding parameters into a new shape. 213 | 214 | Args: 215 | posemb: positional embedding parameters. 216 | num_tokens: desired number of tokens. 217 | has_class_token: True if the positional embedding parameters contain a 218 | class token. 219 | 220 | Returns: 221 | Positional embedding parameters interpolated into the new shape. 222 | """ 223 | assert posemb.shape[0] == 1 224 | if has_class_token: 225 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 226 | num_tokens -= 1 227 | else: 228 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0, 0:] 229 | 230 | gs_old = int(np.sqrt(len(posemb_grid))) 231 | gs_new = int(np.sqrt(num_tokens)) 232 | logging.info('interpolate_posembed: grid-size from %s to %s', gs_old, gs_new) 233 | assert gs_old ** 2 == len(posemb_grid), f'{gs_old ** 2} != {len(posemb_grid)}' 234 | assert gs_new ** 2 == num_tokens, f'{gs_new ** 2} != {num_tokens}' 235 | posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1) 236 | 237 | zoom = (gs_new / gs_old, gs_new / gs_old, 1) 238 | posemb_grid = scipy.ndimage.zoom(posemb_grid, zoom, order=1) 239 | posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1) 240 | return jnp.array(np.concatenate([posemb_tok, posemb_grid], axis=1)) 241 | 242 | 243 | def get_augreg_df(directory='gs://vit_models/augreg'): 244 | """Reads DataFrame describing AugReg models from GCS bucket. 245 | 246 | This function returns a dataframe that describes the models that were 247 | published as part of the paper "How to train your ViT? Data, Augmentation, and 248 | Regularization in Vision Transformers" (https://arxiv.org/abs/TODO). 249 | 250 | Note that every row in the dataset corresponds to a pre-training checkpoint 251 | (column "filename"), and a fine-tuning checkpoint (column "adapt_filename"). 252 | Every pre-trained checkpoint is fine-tuned many times. 253 | 254 | Args: 255 | directory: Pathname of directory containing "index.csv" 256 | 257 | Returns: 258 | Dataframe with the following columns: 259 | - name: Name of the model, as used in descriptions in paper (e.g. "B/16", 260 | or "R26+S/32"). 261 | - ds: Dataset used for pre-training: "i1k" (300 epochs), "i21k" (300 262 | epochs), and "i21k_30" (30 epochs). 263 | - lr: Learning rate used for pre-training. 264 | - aug: Data augmentation used for pre-training. Refer to paper for 265 | details. 266 | - wd: Weight decay used for pre-training. 267 | - do: Dropout used for pre-training. 268 | - sd: Stochastic depth used for pre-training. 269 | - best_val: Best accuracy on validation set that was reached during the 270 | pre-training. Note that "validation set" can refer to minival (meaning 271 | split from training set, as for example for "imagenet2012" dataset). 
272 | - final_val: Final validation set accuracy. 273 | - final_test: Final testset accuracy (in cases where there is no official 274 | testset, like for "imagenet2012", this refers to the validation set). 275 | - adapt_ds: What dataset was used for fine-tuning. 276 | - adapt_lr: Learning rate used for fine-tuning. 277 | - adapt_steps: Number of steps used for fine-tuning (with a fixed batch 278 | size of 512). 279 | - adapt_resolution: Resolution that was used for fine-tuning. 280 | - adapt_final_val: Final validation accuracy after fine-tuning. 281 | - adapt_final_test: Final test accuracy after fine-tuning. 282 | - params: Number of parameters. 283 | - infer_samples_per_sec: Numbers of sample per seconds during inference on 284 | a V100 GPU (measured with `timm` implementation). 285 | - filename: Name of the pre-training checkpoint. Can be found at 286 | "gs://vit_models/augreg/{filename}.npz". 287 | - adapt_filename: Name of the fine-tuning checkpoint. 288 | """ 289 | with gfile.GFile(f'{directory}/index.csv') as f: 290 | return pd.read_csv(f) 291 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/checkpoint_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import tempfile 16 | 17 | from absl.testing import absltest 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | from vit_jax import checkpoint 22 | from vit_jax import models 23 | from vit_jax import test_utils 24 | from vit_jax.configs import models as config_lib 25 | 26 | 27 | class CheckpointTest(absltest.TestCase): 28 | 29 | def test_load_pretrained(self): 30 | tempdir = tempfile.gettempdir() 31 | model_config = config_lib.get_testing_config() 32 | test_utils.create_checkpoint(model_config, f'{tempdir}/testing.npz') 33 | model = models.VisionTransformer(num_classes=2, **model_config) 34 | variables = model.init( 35 | jax.random.PRNGKey(0), 36 | inputs=jnp.ones([1, 32, 32, 3], jnp.float32), 37 | train=False, 38 | ) 39 | checkpoint.load_pretrained( 40 | pretrained_path=f'{tempdir}/testing.npz', 41 | init_params=variables['params'], 42 | model_config=model_config) 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/configs/README.md: -------------------------------------------------------------------------------- 1 | # Configs 2 | 3 | This directory contains `ml_collections.ConfigDict` configurations. It is 4 | structured in a way that factors out common configuration parameters into 5 | `common.py` and model configurations into `models.py`. 
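For instance, the two pieces can also be combined programmatically; a minimal sketch (it simply mirrors what `augreg.py` and `common.py` do, with `R_Ti_16` being one key of `models.AUGREG_CONFIGS`):

```python
from vit_jax.configs import common, models

config = common.get_config()  # shared training defaults (batch size, lr schedule, ...)
config.model = models.AUGREG_CONFIGS['R_Ti_16'].copy_and_resolve_references()
config.model.transformer.dropout_rate = 0.0  # dropout is typically disabled for fine-tuning
config = common.with_dataset(config, 'cifar10')  # applies the per-dataset preset
print(config.total_steps, config.pp.crop)  # 10000 384
```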
6 | 7 | To select one of these configurations you can specify it on the command line: 8 | 9 | ```sh 10 | python -m vit_jax.main --config=$(pwd)/vit_jax/configs/vit.py:b32,cifar10 11 | ``` 12 | 13 | The above example specifies the additional parameter `b32,cifar10` that is 14 | parsed in the file `vit.py` and parametrizes the configuration. 15 | 16 | Note that you can override any configuration parameters at the command line by 17 | specifying additional parameters like `--config.accumulation_steps=1`. 18 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/configs/augreg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Fine-tunes a Vision Transformer / Hybrid from AugReg checkpoint. 16 | 17 | Example for fine-tuning a R+Ti/16 on cifar100: 18 | 19 | python -m vit_jax.main --workdir=/tmp/vit \ 20 | --config=$(pwd)/vit_jax/configs/augreg.py:R_Ti_16 \ 21 | --config.dataset=oxford_iiit_pet \ 22 | --config.pp.train='train[:90%]' \ 23 | --config.base_lr=0.01 24 | 25 | Note that by default, the best i21k pre-trained checkpoint by upstream 26 | validation accuracy is chosen. 
You can also manually select a model by 27 | specifying the full name (without ".npz" extension): 28 | 29 | python -m vit_jax.main --workdir=/tmp/vit \ 30 | --config=$(pwd)/vit_jax/configs/augreg.py:R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0 \ 31 | --config.dataset=oxford_iiit_pet \ 32 | --config.pp.train='train[:90%]' \ 33 | --config.base_lr=0.01 34 | """ 35 | 36 | import ml_collections 37 | 38 | from vit_jax.configs import common 39 | from vit_jax.configs import models 40 | 41 | 42 | def get_config(model_or_filename): 43 | """Returns default parameters for finetuning ViT `model` on `dataset`.""" 44 | config = common.get_config() 45 | 46 | config.pretrained_dir = 'gs://vit_models/augreg' 47 | 48 | config.model_or_filename = model_or_filename 49 | model = model_or_filename.split('-')[0] 50 | if model not in models.AUGREG_CONFIGS: 51 | raise ValueError(f'Unknown Augreg model "{model}"' 52 | f'- not found in {set(models.AUGREG_CONFIGS.keys())}') 53 | config.model = models.AUGREG_CONFIGS[model].copy_and_resolve_references() 54 | config.model.transformer.dropout_rate = 0 # No AugReg during fine-tuning. 55 | 56 | # These values are often overridden on the command line. 57 | config.base_lr = 0.03 58 | config.total_steps = 500 59 | config.warmup_steps = 100 60 | config.pp = ml_collections.ConfigDict() 61 | config.pp.train = 'train' 62 | config.pp.test = 'test' 63 | config.pp.resize = 448 64 | config.pp.crop = 384 65 | 66 | # This value MUST be overridden on the command line. 67 | config.dataset = '' 68 | 69 | return config 70 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/configs/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Any, Dict, Iterable, Tuple, Union 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Returns config values other than model parameters.""" 22 | 23 | config = ml_collections.ConfigDict() 24 | 25 | # Where to search for pretrained ViT models. 26 | # Can be downloaded from gs://vit_models/imagenet21k 27 | config.pretrained_dir = '.' 28 | # Which dataset to finetune on. This can be the name of a tfds dataset 29 | # (see https://www.tensorflow.org/datasets/catalog/overview), or the path to 30 | # a directory with the following structure ($filename can be arbitrary): 31 | # "{train,test}/$class_name/$filename.jpg" 32 | config.dataset = '' 33 | # Path to manually downloaded dataset 34 | config.tfds_manual_dir = None 35 | # Path to tensorflow_datasets directory 36 | config.tfds_data_dir = None 37 | # Number of steps; determined by hyper module if not specified. 38 | config.total_steps = None 39 | 40 | # Resizes global gradients. 41 | config.grad_norm_clip = 1.0 42 | # Datatype to use for momentum state ("bfloat16" or "float32"). 
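  # Storing the momentum in bfloat16 roughly halves the optimizer-state memory
  # compared to float32, at slightly reduced numerical precision.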
43 | config.optim_dtype = 'bfloat16' 44 | # Accumulate gradients over multiple steps to save on memory. 45 | config.accum_steps = 8 46 | 47 | # Batch size for training. 48 | config.batch = 512 49 | # Batch size for evaluation. 50 | config.batch_eval = 512 51 | # Shuffle buffer size. 52 | config.shuffle_buffer = 50_000 53 | # Run prediction on validation set every so many steps 54 | config.eval_every = 100 55 | # Log progress every so many steps. 56 | config.progress_every = 10 57 | # How often to write checkpoints. Specifying 0 disables checkpointing. 58 | config.checkpoint_every = 1_000 59 | 60 | # Number of batches to prefetch to device. 61 | config.prefetch = 2 62 | 63 | # Base learning-rate for fine-tuning. 64 | config.base_lr = 0.03 65 | # How to decay the learning rate ("cosine" or "linear"). 66 | config.decay_type = 'cosine' 67 | # How to decay the learning rate. 68 | config.warmup_steps = 500 69 | 70 | # Alternatives : inference_time. 71 | config.trainer = 'train' 72 | 73 | # Will be set from ./models.py 74 | config.model = None 75 | # Only used in ./augreg.py configs 76 | config.model_or_filename = None 77 | # Must be set via `with_dataset()` 78 | config.dataset = None 79 | config.pp = None 80 | 81 | return config.lock() 82 | 83 | 84 | # We leave out a subset of training for validation purposes (if needed). 85 | DATASET_PRESETS = { 86 | 'cifar10': ml_collections.ConfigDict( 87 | {'total_steps': 10_000, 88 | 'pp': ml_collections.ConfigDict( 89 | {'train': 'train[:98%]', 90 | 'test': 'test', 91 | 'crop': 384}) 92 | }), 93 | 'cifar100': ml_collections.ConfigDict( 94 | {'total_steps': 10_000, 95 | 'pp': ml_collections.ConfigDict( 96 | {'train': 'train[:98%]', 97 | 'test': 'test', 98 | 'crop': 384}) 99 | }), 100 | 'imagenet2012': ml_collections.ConfigDict( 101 | {'total_steps': 20_000, 102 | 'pp': ml_collections.ConfigDict( 103 | {'train': 'train[:99%]', 104 | 'test': 'validation', 105 | 'crop': 384}) 106 | }), 107 | } 108 | 109 | 110 | def with_dataset(config: ml_collections.ConfigDict, 111 | dataset: str) -> ml_collections.ConfigDict: 112 | config = ml_collections.ConfigDict(config.to_dict()) 113 | config.dataset = dataset 114 | config.update(DATASET_PRESETS[dataset]) 115 | return config 116 | 117 | 118 | def flatten( 119 | config: Union[ml_collections.ConfigDict, Dict[str, Any]], 120 | prefix: Tuple[str, ...] = ('config',) 121 | ) -> Iterable[Tuple[str, Any]]: 122 | """Returns a flat representation of `config`, e.g. for use in sweeps.""" 123 | for k, v in config.items(): 124 | if isinstance(v, (dict, ml_collections.ConfigDict)): 125 | yield from flatten(v, prefix + (k,)) 126 | else: 127 | yield ('.'.join(prefix + (k,)), v) 128 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/configs/inference_time.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import ml_collections 16 | 17 | 18 | def get_config(): 19 | """Returns a configuration for inference_time.py.""" 20 | config = ml_collections.ConfigDict() 21 | 22 | # Which model to use -- see ./models.py 23 | config.model_name = 'ViT-B_32' 24 | # Where to store training logs. 25 | config.log_dir = '.' 26 | 27 | # Number of steps to measure. 28 | config.steps = 30 29 | # Number of steps before measuring. 30 | config.initial_steps = 10 31 | 32 | # Batch size 33 | config.batch = 0 34 | # Number of output classes. 35 | config.num_classes = 0 36 | # Image size (width=height). 37 | config.image_size = 0 38 | 39 | config.train = 'inference_time' 40 | 41 | return config 42 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/configs/mixer_base16_cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import ml_collections 16 | 17 | from vit_jax.configs import common 18 | from vit_jax.configs import models 19 | 20 | 21 | def get_config(): 22 | """Returns config for training Mixer-B/16 on cifar10.""" 23 | config = common.get_config() 24 | config.model_type = 'Mixer' 25 | config.model = models.get_mixer_b16_config() 26 | config.dataset = 'cifar10' 27 | config.total_steps = 10_000 28 | config.pp = ml_collections.ConfigDict( 29 | {'train': 'train[:98%]', 'test': 'test', 'crop': 224}) 30 | return config 31 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/configs/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | r"""Fine-tunes a Vision Transformer. 
16 | 17 | Example for fine-tuning a ViT-B/16 on CIFAR10: 18 | 19 | python -m vit_jax.main --workdir=/tmp/vit \ 20 | --config=$(pwd)/vit_jax/configs/vit.py:b16,cifar10 \ 21 | --config.pretrained_dir='gs://vit_models/imagenet21k' 22 | """ 23 | 24 | from vit_jax.configs import common 25 | from vit_jax.configs import models 26 | 27 | 28 | def get_config(model_dataset): 29 | """Returns default parameters for finetuning ViT `model` on `dataset`.""" 30 | model, dataset = model_dataset.split(',') 31 | config = common.with_dataset(common.get_config(), dataset) 32 | get_model_config = getattr(models, f'get_{model}_config') 33 | config.model = get_model_config() 34 | 35 | if model == 'b16' and dataset == 'cifar10': 36 | config.base_lr = 0.01 37 | 38 | return config 39 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/inference_time.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | import os 17 | import time 18 | 19 | from absl import logging 20 | from clu import metric_writers 21 | import flax 22 | import flax.jax_utils as flax_utils 23 | import jax 24 | import jax.numpy as jnp 25 | import ml_collections 26 | import numpy as np 27 | import tensorflow as tf 28 | 29 | from vit_jax import checkpoint 30 | from vit_jax import models 31 | from vit_jax.configs import models as config_lib 32 | 33 | 34 | def inference_time(config: ml_collections.ConfigDict, workdir: str): 35 | """Runs a number of steps and measures inference time.""" 36 | 37 | assert config.batch, f'Expected --config.batch={config.batch} > 0' 38 | assert config.num_classes, ( 39 | f'Expected --config.num_classes={config.num_classes} > 0') 40 | assert config.image_size, ( 41 | f'Expected --config.image_size={config.image_size} > 0') 42 | 43 | # Build VisionTransformer architecture 44 | model_config = config_lib.MODEL_CONFIGS[config.model_name] 45 | model = models.VisionTransformer( 46 | num_classes=config.num_classes, **model_config) 47 | 48 | # Make sure initial model parameters (before replication) are on CPU only. 49 | @functools.partial(jax.jit, backend='cpu') 50 | def init(rng): 51 | return model.init( 52 | rng, 53 | # Discard the "num_local_devices" dimension for initialization. 
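        # (A single dummy example is enough to create the parameters; the real
        # per-device batch `images` is built below and fed through `jax.pmap`.)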
54 | inputs=jnp.ones([1, config.image_size, config.image_size, 3], 55 | jnp.float32), 56 | train=False) 57 | 58 | variables = init(jax.random.PRNGKey(0)) 59 | 60 | params_repl = flax_utils.replicate(variables['params']) 61 | 62 | # pmap replicates the models over all TPUs/GPUs 63 | vit_fn_repl = jax.pmap(functools.partial(model.apply, train=False)) 64 | images = jnp.ones([ 65 | jax.local_device_count(), config.batch // jax.local_device_count(), 66 | config.image_size, config.image_size, 3 67 | ], jnp.float32) 68 | 69 | writer = metric_writers.create_default_writer(workdir, asynchronous=False) 70 | writer.write_hparams(config.to_dict()) 71 | 72 | logging.info('Starting training loop; initial compile can take a while...') 73 | logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images) 74 | logits.block_until_ready() 75 | logging.info('Done.') 76 | 77 | logging.info('Going to run %d inferences WITHOUT measuring...', 78 | config.initial_steps) 79 | for _ in range(config.initial_steps): 80 | logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images) 81 | logits.block_until_ready() 82 | 83 | logging.info('Going to run %d inferences measuring...', config.steps) 84 | times = [] 85 | for _ in range(config.initial_steps): 86 | t0 = time.time() 87 | logits = vit_fn_repl(flax.core.FrozenDict(params=params_repl), images) 88 | logits.block_until_ready() 89 | times.append(time.time() - t0) 90 | logging.info('times=%s', times) 91 | imgs_sec_core = config.batch / jax.local_device_count() / np.array(times) 92 | logging.info('imgs_sec_core_min=%f', imgs_sec_core.min()) 93 | logging.info('imgs_sec_core_max=%f', imgs_sec_core.max()) 94 | logging.info('imgs_sec_core_mean=%f', imgs_sec_core.mean()) 95 | logging.info('imgs_sec_core_std=%f', imgs_sec_core.std()) 96 | writer.write_scalars( 97 | 0, 98 | dict( 99 | imgs_sec_core_min=imgs_sec_core.min(), 100 | imgs_sec_core_max=imgs_sec_core.max(), 101 | imgs_sec_core_mean=imgs_sec_core.mean(), 102 | imgs_sec_core_std=imgs_sec_core.std(), 103 | )) 104 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/inference_time_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import glob 16 | import tempfile 17 | 18 | from absl.testing import absltest 19 | 20 | from vit_jax import inference_time 21 | from vit_jax import test_utils 22 | from vit_jax.configs import inference_time as config_lib 23 | from vit_jax.configs import models 24 | 25 | 26 | class InferenceTimeTest(absltest.TestCase): 27 | 28 | def test_main(self): 29 | config = config_lib.get_config() 30 | config.num_classes = 10 31 | config.image_size = 224 32 | config.batch = 8 33 | config.model_name = 'testing' 34 | model_config = models.get_testing_config() 35 | 36 | workdir = tempfile.gettempdir() 37 | config.pretrained_dir = workdir 38 | test_utils.create_checkpoint(model_config, f'{workdir}/testing.npz') 39 | inference_time.inference_time(config, workdir) 40 | self.assertNotEmpty(glob.glob(f'{workdir}/events.out.tfevents.*')) 41 | 42 | 43 | if __name__ == '__main__': 44 | absltest.main() 45 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import glob 16 | import os 17 | 18 | from absl import logging 19 | import flax 20 | import jax 21 | import numpy as np 22 | import tensorflow as tf 23 | import tensorflow_datasets as tfds 24 | 25 | import sys 26 | if sys.platform != 'darwin': 27 | # A workaround to avoid crash because tfds may open to many files. 28 | import resource 29 | low, high = resource.getrlimit(resource.RLIMIT_NOFILE) 30 | resource.setrlimit(resource.RLIMIT_NOFILE, (high, high)) 31 | 32 | # Adjust depending on the available RAM. 33 | MAX_IN_MEMORY = 200_000 34 | 35 | 36 | def get_tfds_info(dataset, split): 37 | """Returns information about tfds dataset -- see `get_dataset_info()`.""" 38 | data_builder = tfds.builder(dataset) 39 | return dict( 40 | num_examples=data_builder.info.splits[split].num_examples, 41 | num_classes=data_builder.info.features['label'].num_classes, 42 | int2str=data_builder.info.features['label'].int2str, 43 | examples_glob=None, 44 | ) 45 | 46 | 47 | def get_directory_info(directory): 48 | """Returns information about directory dataset -- see `get_dataset_info()`.""" 49 | examples_glob = f'{directory}/*/*.jpg' 50 | paths = glob.glob(examples_glob) 51 | get_classname = lambda path: path.split('/')[-2] 52 | class_names = sorted(set(map(get_classname, paths))) 53 | return dict( 54 | num_examples=len(paths), 55 | num_classes=len(class_names), 56 | int2str=lambda id_: class_names[id_], 57 | examples_glob=examples_glob, 58 | ) 59 | 60 | 61 | def get_dataset_info(dataset, split): 62 | """Returns information about a dataset. 63 | 64 | Args: 65 | dataset: Name of tfds dataset or directory -- see `./configs/common.py` 66 | split: Which split to return data for (e.g. "test", or "train"; tfds also 67 | supports splits like "test[:90%]"). 
68 | 69 | Returns: 70 | A dictionary with the following keys: 71 | - num_examples: Number of examples in dataset/mode. 72 | - num_classes: Number of classes in dataset. 73 | - int2str: Function converting class id to class name. 74 | - examples_glob: Glob to select all files, or None (for tfds dataset). 75 | """ 76 | directory = os.path.join(dataset, split) 77 | if os.path.isdir(directory): 78 | return get_directory_info(directory) 79 | return get_tfds_info(dataset, split) 80 | 81 | 82 | def get_datasets(config): 83 | """Returns `ds_train, ds_test` for specified `config`.""" 84 | 85 | if os.path.isdir(config.dataset): 86 | train_dir = os.path.join(config.dataset, 'train') 87 | test_dir = os.path.join(config.dataset, 'test') 88 | if not os.path.isdir(train_dir): 89 | raise ValueError('Expected to find directories"{}" and "{}"'.format( 90 | train_dir, 91 | test_dir, 92 | )) 93 | logging.info('Reading dataset from directories "%s" and "%s"', train_dir, 94 | test_dir) 95 | ds_train = get_data_from_directory( 96 | config=config, directory=train_dir, mode='train') 97 | ds_test = get_data_from_directory( 98 | config=config, directory=test_dir, mode='test') 99 | else: 100 | logging.info('Reading dataset from tfds "%s"', config.dataset) 101 | ds_train = get_data_from_tfds(config=config, mode='train') 102 | ds_test = get_data_from_tfds(config=config, mode='test') 103 | 104 | return ds_train, ds_test 105 | 106 | 107 | def get_data_from_directory(*, config, directory, mode): 108 | """Returns dataset as read from specified `directory`.""" 109 | 110 | dataset_info = get_directory_info(directory) 111 | data = tf.data.Dataset.list_files(dataset_info['examples_glob']) 112 | class_names = [ 113 | dataset_info['int2str'](id_) for id_ in range(dataset_info['num_classes']) 114 | ] 115 | 116 | def _pp(path): 117 | return dict( 118 | image=path, 119 | label=tf.where( 120 | tf.strings.split(path, '/')[-2] == class_names 121 | )[0][0], 122 | ) 123 | 124 | image_decoder = lambda path: tf.image.decode_jpeg(tf.io.read_file(path), 3) 125 | 126 | return get_data( 127 | data=data, 128 | mode=mode, 129 | num_classes=dataset_info['num_classes'], 130 | image_decoder=image_decoder, 131 | repeats=None if mode == 'train' else 1, 132 | batch_size=config.batch_eval if mode == 'test' else config.batch, 133 | image_size=config.pp['crop'], 134 | shuffle_buffer=min(dataset_info['num_examples'], config.shuffle_buffer), 135 | preprocess=_pp) 136 | 137 | 138 | def get_data_from_tfds(*, config, mode): 139 | """Returns dataset as read from tfds dataset `config.dataset`.""" 140 | 141 | data_builder = tfds.builder(config.dataset, data_dir=config.tfds_data_dir) 142 | 143 | data_builder.download_and_prepare( 144 | download_config=tfds.download.DownloadConfig( 145 | manual_dir=config.tfds_manual_dir)) 146 | data = data_builder.as_dataset( 147 | split=config.pp[mode], 148 | # Reduces memory footprint in shuffle buffer. 
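      # (With `SkipDecoding` the images stay as encoded JPEG bytes while they
      # sit in the shuffle buffer; they are only decoded later, in `get_data`,
      # via the `image_decoder` defined just below.)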
149 | decoders={'image': tfds.decode.SkipDecoding()}, 150 | shuffle_files=mode == 'train') 151 | image_decoder = data_builder.info.features['image'].decode_example 152 | 153 | dataset_info = get_tfds_info(config.dataset, config.pp[mode]) 154 | return get_data( 155 | data=data, 156 | mode=mode, 157 | num_classes=dataset_info['num_classes'], 158 | image_decoder=image_decoder, 159 | repeats=None if mode == 'train' else 1, 160 | batch_size=config.batch_eval if mode == 'test' else config.batch, 161 | image_size=config.pp['crop'], 162 | shuffle_buffer=min(dataset_info['num_examples'], config.shuffle_buffer)) 163 | 164 | 165 | def get_data(*, 166 | data, 167 | mode, 168 | num_classes, 169 | image_decoder, 170 | repeats, 171 | batch_size, 172 | image_size, 173 | shuffle_buffer, 174 | preprocess=None): 175 | """Returns dataset for training/eval. 176 | 177 | Args: 178 | data: tf.data.Dataset to read data from. 179 | mode: Must be "train" or "test". 180 | num_classes: Number of classes (used for one-hot encoding). 181 | image_decoder: Applied to `features['image']` after shuffling. Decoding the 182 | image after shuffling allows for a larger shuffle buffer. 183 | repeats: How many times the dataset should be repeated. For indefinite 184 | repeats specify None. 185 | batch_size: Global batch size. Note that the returned dataset will have 186 | dimensions [local_devices, batch_size / local_devices, ...]. 187 | image_size: Image size after cropping (for training) / resizing (for 188 | evaluation). 189 | shuffle_buffer: Number of elements to preload the shuffle buffer with. 190 | preprocess: Optional preprocess function. This function will be applied to 191 | the dataset just after repeat/shuffling, and before the data augmentation 192 | preprocess step is applied. 193 | """ 194 | 195 | def _pp(data): 196 | im = image_decoder(data['image']) 197 | if mode == 'train': 198 | channels = im.shape[-1] 199 | begin, size, _ = tf.image.sample_distorted_bounding_box( 200 | tf.shape(im), 201 | tf.zeros([0, 0, 4], tf.float32), 202 | area_range=(0.05, 1.0), 203 | min_object_covered=0, # Don't enforce a minimum area. 204 | use_image_if_no_bounding_boxes=True) 205 | im = tf.slice(im, begin, size) 206 | # Unfortunately, the above operation loses the depth-dimension. So we 207 | # need to restore it the manual way. 
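      # (`set_shape` only re-attaches static shape information, here the
      # channel count saved above; it does not change the tensor's values.)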
208 | im.set_shape([None, None, channels]) 209 | im = tf.image.resize(im, [image_size, image_size]) 210 | if tf.random.uniform(shape=[]) > 0.5: 211 | im = tf.image.flip_left_right(im) 212 | else: 213 | im = tf.image.resize(im, [image_size, image_size]) 214 | im = (im - 127.5) / 127.5 215 | label = tf.one_hot(data['label'], num_classes) # pylint: disable=no-value-for-parameter 216 | return {'image': im, 'label': label} 217 | 218 | data = data.repeat(repeats) 219 | if mode == 'train': 220 | data = data.shuffle(shuffle_buffer) 221 | if preprocess is not None: 222 | data = data.map(preprocess, tf.data.experimental.AUTOTUNE) 223 | data = data.map(_pp, tf.data.experimental.AUTOTUNE) 224 | data = data.batch(batch_size, drop_remainder=True) 225 | 226 | # Shard data such that it can be distributed accross devices 227 | num_devices = jax.local_device_count() 228 | 229 | def _shard(data): 230 | data['image'] = tf.reshape(data['image'], 231 | [num_devices, -1, image_size, image_size, 232 | data['image'].shape[-1]]) 233 | data['label'] = tf.reshape(data['label'], 234 | [num_devices, -1, num_classes]) 235 | return data 236 | 237 | if num_devices is not None: 238 | data = data.map(_shard, tf.data.experimental.AUTOTUNE) 239 | 240 | return data.prefetch(1) 241 | 242 | 243 | def prefetch(dataset, n_prefetch): 244 | """Prefetches data to device and converts to numpy array.""" 245 | ds_iter = iter(dataset) 246 | ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x), 247 | ds_iter) 248 | if n_prefetch: 249 | ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch) 250 | return ds_iter 251 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from absl import app 16 | from absl import flags 17 | from absl import logging 18 | from clu import platform 19 | import jax 20 | from ml_collections import config_flags 21 | import tensorflow as tf 22 | 23 | from vit_jax import inference_time 24 | from vit_jax import train 25 | from vit_jax import utils 26 | 27 | FLAGS = flags.FLAGS 28 | 29 | _WORKDIR = flags.DEFINE_string('workdir', None, 30 | 'Directory to store logs and model data.') 31 | config_flags.DEFINE_config_file( 32 | 'config', 33 | None, 34 | 'File path to the training hyperparameter configuration.', 35 | lock_config=True) 36 | flags.mark_flags_as_required(['config', 'workdir']) 37 | # Flags --jax_backend_target and --jax_xla_backend are available through JAX. 38 | 39 | 40 | def main(argv): 41 | if len(argv) > 1: 42 | raise app.UsageError('Too many command-line arguments.') 43 | 44 | utils.add_gfile_logger(_WORKDIR.value) 45 | 46 | # Hide any GPUs form TensorFlow. Otherwise TF might reserve memory and make 47 | # it unavailable to JAX. 
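  # (TensorFlow is only used for the tf.data input pipeline here, so it can
  # run entirely on the CPU.)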
48 | tf.config.experimental.set_visible_devices([], 'GPU') 49 | 50 | jax.config.update('jax_log_compiles', True) 51 | 52 | logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count()) 53 | logging.info('JAX local devices: %r', jax.local_devices()) 54 | jax_xla_backend = ('None' if FLAGS.jax_xla_backend is None else 55 | FLAGS.jax_xla_backend) 56 | logging.info('Using JAX XLA backend %s', jax_xla_backend) 57 | 58 | logging.info('Config: %s', FLAGS.config) 59 | 60 | # Add a note so that we can tell which task is which JAX host. 61 | # (Depending on the platform task 0 is not guaranteed to be host 0) 62 | platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, ' 63 | f'process_count: {jax.process_count()}') 64 | platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, 65 | _WORKDIR.value, 'workdir') 66 | 67 | if FLAGS.config.trainer == 'train': 68 | train.train_and_evaluate(FLAGS.config, _WORKDIR.value) 69 | elif FLAGS.config.trainer == 'inference_time': 70 | inference_time.inference_time(FLAGS.config, _WORKDIR.value) 71 | else: 72 | raise app.UsageError(f'Unknown trainer: {FLAGS.config.trainer}') 73 | 74 | 75 | if __name__ == '__main__': 76 | # Provide access to --jax_backend_target and --jax_xla_backend flags. 77 | jax.config.config_with_absl() 78 | app.run(main) 79 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from vit_jax import models_lit 16 | from vit_jax import models_mixer 17 | from vit_jax import models_vit 18 | from vit_jax.configs import models as model_configs 19 | 20 | # Note that you probably want to import the individual modules separately 21 | # instead (e.g. not depending on tensorflow_text required by models_lit if 22 | # you're only interested in image models). 
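# Illustrative usage (the name must be a key of `model_configs.MODEL_CONFIGS`):
#   model = get_model('ViT-B_16', num_classes=1000)
#   variables = model.init(jax.random.PRNGKey(0), jnp.ones([1, 224, 224, 3]),
#                          train=False)
#   logits = model.apply(variables, images, train=False)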
23 | AddPositionEmbs = models_vit.AddPositionEmbs 24 | MlpBlock = models_vit.MlpBlock 25 | Encoder1DBlock = models_vit.Encoder1DBlock 26 | Encoder = models_vit.Encoder 27 | 28 | LitModel = models_lit.LitModel 29 | MlpMixer = models_mixer.MlpMixer 30 | VisionTransformer = models_vit.VisionTransformer 31 | 32 | 33 | def get_model(name, **kw): 34 | """Returns a model as specified in `model_configs.MODEL_CONFIGS`.""" 35 | if name.startswith('Mixer-'): 36 | return MlpMixer(**model_configs.MODEL_CONFIGS[name], **kw) 37 | elif name.startswith('LiT-'): 38 | return LitModel(**model_configs.MODEL_CONFIGS[name], **kw) 39 | else: 40 | return VisionTransformer(**model_configs.MODEL_CONFIGS[name], **kw) 41 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/models_lit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Models from Locked-image text Tuning. 16 | 17 | See paper https://arxiv.org/abs/2111.07991 18 | """ 19 | 20 | import dataclasses 21 | import os 22 | from typing import Optional, Tuple 23 | 24 | import flax.linen as nn 25 | import jax.numpy as jnp 26 | import ml_collections 27 | from vit_jax import checkpoint 28 | from vit_jax import models_vit 29 | from vit_jax import preprocess 30 | 31 | from flaxformer.architectures.bert import bert 32 | from flaxformer.architectures.bert import configs 33 | 34 | 35 | BASE_PATH = 'gs://vit_models/lit' 36 | 37 | 38 | class BertModel(nn.Module): 39 | """BERT encoder with linear projection on last layer CLS token.""" 40 | 41 | config: str 42 | num_classes: Optional[int] = None 43 | 44 | @nn.compact 45 | def __call__(self, tokens): 46 | out = {} 47 | 48 | batch_size, max_len = tokens.shape 49 | bert_model = bert.BertEncoder(**dataclasses.asdict({ 50 | 'base': configs.BertBaseConfig(), 51 | 'large': configs.BertLargeConfig(), 52 | }[self.config])) 53 | x = out['transformed'] = bert_model( 54 | token_ids=tokens, 55 | position_ids=jnp.tile( 56 | jnp.arange(0, max_len, dtype=jnp.int32), [batch_size, 1]), 57 | segment_ids=jnp.zeros([batch_size, max_len], dtype=jnp.int32), 58 | input_mask=tokens.astype(jnp.bool_).astype(jnp.int32), 59 | enable_dropout=False, 60 | ) 61 | 62 | x = out['pre_logits'] = x[:, 0] # CLS token 63 | 64 | if self.num_classes: 65 | x = out['logits'] = nn.Dense(self.num_classes, name='head')(x) 66 | 67 | return x, out 68 | 69 | 70 | class TextTransformer(nn.Module): 71 | """Simple text transformer.""" 72 | 73 | num_classes: int 74 | width: int = 512 75 | num_layers: int = 12 76 | mlp_dim: int = 2048 77 | num_heads: int = 8 78 | dropout_rate: float = 0.0 79 | vocab_size: int = 32_000 80 | 81 | @nn.compact 82 | def __call__(self, x): 83 | out = {} 84 | 85 | embedding = nn.Embed(num_embeddings=self.vocab_size, features=self.width) 86 | x = out['embedded'] = embedding(x) 87 | 88 | # Add posemb 89 
| n, l, d = x.shape # pylint: disable=unused-variable 90 | x = x + self.param('pos_embedding', 91 | nn.initializers.normal(stddev=1 / jnp.sqrt(d)), 92 | (1, l, d), x.dtype) 93 | 94 | x = models_vit.Encoder( 95 | num_layers=self.num_layers, 96 | mlp_dim=self.mlp_dim, 97 | num_heads=self.num_heads, 98 | dropout_rate=self.dropout_rate, 99 | attention_dropout_rate=0, 100 | add_position_embedding=False)( 101 | x, train=False) 102 | 103 | x = out['pre_logits'] = x[:, -1, :] # note that we take *last* token 104 | x = out['logits'] = nn.Dense(self.num_classes, name='head')(x) 105 | 106 | return x, out 107 | 108 | 109 | class LitModel(nn.Module): 110 | """Locked-image text Tuning model. 111 | 112 | See paper https://arxiv.org/abs/2111.07991 113 | 114 | For examples, refer to Colab 115 | 116 | https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb 117 | 118 | Attributes: 119 | image: Configuration for ViT image tower. 120 | text: Configuration for text tower. 121 | pp: Preprocessing configuration. 122 | out_dim: Size of optional image/text heads that are added to the towers. 123 | model_name: Refers to the key in `model_configs.MODEL_CONFIGS`. 124 | """ 125 | 126 | image: ml_collections.ConfigDict 127 | text_model: str 128 | text: ml_collections.ConfigDict 129 | pp: ml_collections.ConfigDict 130 | out_dim: Tuple[Optional[int], Optional[int]] 131 | model_name: str 132 | 133 | def load_variables(self, path=None, cache=True): 134 | """Loads variables. 135 | 136 | Args: 137 | path: Path to load params from. If not specified, then the parms will be 138 | loaded from the default public Cloud storage path, unless they exist in 139 | the current working directory. 140 | cache: If set to `True` and `path` is not specified (the default), then 141 | the files will be copied from Cloud and stored in the current working 142 | directory. 143 | 144 | Returns: 145 | The module variables, to be used with `model.apply()` 146 | """ 147 | if path is None: 148 | local_path = f'{self.model_name}.npz' 149 | if not os.path.exists(local_path): 150 | path = f'{BASE_PATH}/{self.model_name}.npz' 151 | print('Loading params from cloud:', path) 152 | if cache: 153 | checkpoint.copy(path, local_path) 154 | if os.path.exists(local_path): 155 | print('\n⚠️ Reusing local copy:', local_path) 156 | path = local_path 157 | return {'params': checkpoint.load(path)} 158 | 159 | @property 160 | def vocab_path(self): 161 | ext = { 162 | 'bert': 'txt', 163 | 'sentencepiece': 'model', 164 | }[self.pp.tokenizer_name] 165 | return f'{BASE_PATH}/{self.model_name}.{ext}' 166 | 167 | def get_pp(self, crop=False): 168 | """Returns a preprocessing function suitable for `tf.data.Dataset.map()`.""" 169 | return preprocess.get_pp( 170 | tokenizer_name=self.pp.tokenizer_name, 171 | vocab_path=self.vocab_path, 172 | max_len=self.pp.max_len, 173 | size=self.pp.size, 174 | crop=crop) 175 | 176 | def get_tokenizer(self): 177 | """Returns a tokenizer.""" 178 | return preprocess.get_tokenizer(self.pp.tokenizer_name)( 179 | vocab_path=self.vocab_path, 180 | max_len=self.pp.max_len) 181 | 182 | def get_image_preprocessing(self, crop=False): 183 | """Returns a function to pre-process images (resize, value range).""" 184 | return preprocess.PreprocessImages(size=self.pp.size, crop=crop) 185 | 186 | @nn.compact 187 | def __call__(self, *, images=None, tokens=None): 188 | """Embeds images and/or tokens. 
189 | 190 | Args: 191 | images: Batch of images, prepared with the function returned by 192 | `get_image_preprocessing()` or `get_pp()`. 193 | tokens: Batch of tokens, prepared with the function returned by 194 | `get_tokenizer()` or `get_pp()`. 195 | 196 | Returns: 197 | A tuple of `(zimg, ztxt, out)`, where `zimg` is a batch of embeddings for 198 | the images (or `None`, if images were not specified), `ztxt` is a batch 199 | of embeddings for the tokens (or `None`, if tokens were not specified), 200 | and `out` is a dictionary of additional values, such as `out['t']` that 201 | is the temperature multiplied with the vector dot products before the 202 | softmax is applied. 203 | """ 204 | 205 | # Support calling without text or without images, for example for few-shot. 206 | ztxt, zimg = None, None 207 | out = {} 208 | out_dims = self.out_dim 209 | if isinstance(out_dims, int): 210 | out_dims = (out_dims, out_dims) 211 | 212 | if tokens is not None: 213 | # Embed the text: 214 | model_class = { 215 | 'bert': BertModel, 216 | 'text_transformer': TextTransformer, 217 | }[self.text_model] 218 | text_model = model_class( 219 | **{ 220 | 'num_classes': out_dims[1], 221 | **(self.text or {}) 222 | }, name='txt') 223 | 224 | ztxt, out_txt = text_model(tokens) 225 | for k, v in out_txt.items(): 226 | out[f'txt/{k}'] = v 227 | 228 | # Normalize the embeddings the models give us. 229 | out['txt/norm'] = jnp.linalg.norm(ztxt, axis=1, keepdims=True) 230 | out['txt/normalized'] = ztxt = ztxt / (out['txt/norm'] + 1e-8) 231 | 232 | if images is not None: 233 | image_model = models_vit.VisionTransformer( 234 | **{ 235 | **self.image, 236 | 'num_classes': out_dims[0], 237 | }, name='img') # pylint: disable=not-a-mapping 238 | zimg = image_model(images, train=False) 239 | 240 | # Normalize the embeddings the models give us. 241 | out['img/norm'] = jnp.linalg.norm(zimg, axis=1, keepdims=True) 242 | out['img/normalized'] = zimg = zimg / (out['img/norm'] + 1e-8) 243 | 244 | t = self.param('t', nn.initializers.zeros, (1,), jnp.float32) 245 | out['t'] = jnp.exp(t) 246 | 247 | return zimg, ztxt, out 248 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/models_mixer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Any, Optional 16 | 17 | import einops 18 | import flax.linen as nn 19 | import jax.numpy as jnp 20 | 21 | 22 | class MlpBlock(nn.Module): 23 | mlp_dim: int 24 | 25 | @nn.compact 26 | def __call__(self, x): 27 | y = nn.Dense(self.mlp_dim)(x) 28 | y = nn.gelu(y) 29 | return nn.Dense(x.shape[-1])(y) 30 | 31 | 32 | class MixerBlock(nn.Module): 33 | """Mixer block layer.""" 34 | tokens_mlp_dim: int 35 | channels_mlp_dim: int 36 | 37 | @nn.compact 38 | def __call__(self, x): 39 | y = nn.LayerNorm()(x) 40 | y = jnp.swapaxes(y, 1, 2) 41 | y = MlpBlock(self.tokens_mlp_dim, name='token_mixing')(y) 42 | y = jnp.swapaxes(y, 1, 2) 43 | x = x + y 44 | y = nn.LayerNorm()(x) 45 | return x + MlpBlock(self.channels_mlp_dim, name='channel_mixing')(y) 46 | 47 | 48 | class MlpMixer(nn.Module): 49 | """Mixer architecture.""" 50 | patches: Any 51 | num_classes: int 52 | num_blocks: int 53 | hidden_dim: int 54 | tokens_mlp_dim: int 55 | channels_mlp_dim: int 56 | model_name: Optional[str] = None 57 | 58 | @nn.compact 59 | def __call__(self, inputs, *, train, return_acts=False, return_layer=-1): 60 | del train 61 | x = nn.Conv(self.hidden_dim, self.patches.size, 62 | strides=self.patches.size, name='stem')(inputs) 63 | x = einops.rearrange(x, 'n h w c -> n (h w) c') 64 | num_layers = return_layer if return_acts and return_layer != -1 else self.num_blocks 65 | for _ in range(num_layers): 66 | x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x) 67 | if return_acts: 68 | return (x, x, x, x) 69 | x = nn.LayerNorm(name='pre_head_layer_norm')(x) 70 | x = jnp.mean(x, axis=1) 71 | if self.num_classes: 72 | x = nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros, 73 | name='head')(x) 74 | return x 75 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/models_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Callable, Sequence, TypeVar 16 | 17 | from flax import linen as nn 18 | import jax.numpy as jnp 19 | 20 | T = TypeVar('T') 21 | 22 | 23 | def weight_standardize(w, axis, eps): 24 | """Subtracts mean and divides by standard deviation.""" 25 | w = w - jnp.mean(w, axis=axis) 26 | w = w / (jnp.std(w, axis=axis) + eps) 27 | return w 28 | 29 | 30 | class StdConv(nn.Conv): 31 | """Convolution with weight standardization.""" 32 | 33 | def param(self, 34 | name: str, 35 | init_fn: Callable[..., T], 36 | *init_args) -> T: 37 | param = super().param(name, init_fn, *init_args) 38 | if name == 'kernel': 39 | param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5) 40 | return param 41 | 42 | 43 | class ResidualUnit(nn.Module): 44 | """Bottleneck ResNet block.""" 45 | 46 | features: int 47 | strides: Sequence[int] = (1, 1) 48 | 49 | @nn.compact 50 | def __call__(self, x): 51 | needs_projection = ( 52 | x.shape[-1] != self.features * 4 or self.strides != (1, 1)) 53 | 54 | residual = x 55 | if needs_projection: 56 | residual = StdConv( 57 | features=self.features * 4, 58 | kernel_size=(1, 1), 59 | strides=self.strides, 60 | use_bias=False, 61 | name='conv_proj')( 62 | residual) 63 | residual = nn.GroupNorm(name='gn_proj')(residual) 64 | 65 | y = StdConv( 66 | features=self.features, 67 | kernel_size=(1, 1), 68 | use_bias=False, 69 | name='conv1')( 70 | x) 71 | y = nn.GroupNorm(name='gn1')(y) 72 | y = nn.relu(y) 73 | y = StdConv( 74 | features=self.features, 75 | kernel_size=(3, 3), 76 | strides=self.strides, 77 | use_bias=False, 78 | name='conv2')( 79 | y) 80 | y = nn.GroupNorm(name='gn2')(y) 81 | y = nn.relu(y) 82 | y = StdConv( 83 | features=self.features * 4, 84 | kernel_size=(1, 1), 85 | use_bias=False, 86 | name='conv3')( 87 | y) 88 | 89 | y = nn.GroupNorm(name='gn3', scale_init=nn.initializers.zeros)(y) 90 | y = nn.relu(residual + y) 91 | return y 92 | 93 | 94 | class ResNetStage(nn.Module): 95 | """A ResNet stage.""" 96 | 97 | block_size: Sequence[int] 98 | nout: int 99 | first_stride: Sequence[int] 100 | 101 | @nn.compact 102 | def __call__(self, x): 103 | x = ResidualUnit(self.nout, strides=self.first_stride, name='unit1')(x) 104 | for i in range(1, self.block_size): 105 | x = ResidualUnit(self.nout, strides=(1, 1), name=f'unit{i + 1}')(x) 106 | return x 107 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/models_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from absl.testing import absltest 16 | from absl.testing import parameterized 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | from vit_jax import models 22 | from vit_jax.configs import models as config_lib 23 | 24 | 25 | MODEL_SIZES = { 26 | 'LiT-B16B': 195_871_489, 27 | 'LiT-B16B_2': 195_280_897, 28 | 'LiT-L16L': 638_443_521, 29 | 'LiT-L16S': 331_140_353, 30 | 'LiT-L16Ti': 311_913_089., 31 | 'Mixer-B_16': 59_880_472, 32 | 'Mixer-B_32': 60_293_428, 33 | 'Mixer-L_16': 208_196_168, 34 | 'R+ViT-Ti_16': 6_337_704, 35 | 'R26+ViT-B_32': 101_383_976, 36 | 'R26+ViT-S_32': 36_431_912, 37 | 'R50+ViT-B_16': 98_659_112, 38 | 'R50+ViT-L_32': 328_994_856, 39 | 'ViT-B_8': 86_576_872, 40 | 'ViT-B_16': 86_567_656, 41 | 'ViT-B_16-gap-norep': 86_566_120, 42 | 'ViT-B_32': 88_224_232, 43 | 'ViT-B_32-gap-norep': 88_222_696, 44 | 'ViT-H_14': 632_045_800, 45 | 'ViT-L_16': 304_326_632, 46 | 'ViT-L_32': 306_535_400, 47 | 'ViT-S_16': 22_050_664, 48 | 'ViT-S_16-gap-norep': 22_049_896, 49 | 'ViT-S_32': 22_878_952, 50 | 'ViT-S_32-gap-norep': 22_878_184, 51 | 'ViT-Ti_16': 5_717_416, 52 | 'testing': 21_390, 53 | 'testing-unpooled': 21_370, 54 | } 55 | 56 | 57 | class ModelsTest(parameterized.TestCase): 58 | 59 | def test_all_tested(self): 60 | self.assertEmpty(set(config_lib.MODEL_CONFIGS).difference(MODEL_SIZES)) 61 | 62 | @parameterized.parameters(*list(MODEL_SIZES.items())) 63 | def test_can_instantiate(self, name, size): 64 | rng = jax.random.PRNGKey(0) 65 | kw = {} if name.startswith('LiT-') else dict(num_classes=1_000) 66 | model = models.get_model(name, **kw) 67 | batch_size = 2 68 | images = jnp.ones([batch_size, 224, 224, 3], jnp.float32) 69 | if name.startswith('LiT-'): 70 | tokens = jnp.ones([batch_size, model.pp.max_len], jnp.int32) 71 | variables = model.init(rng, images=images, tokens=tokens) 72 | zimg, ztxt, _ = model.apply(variables, images=images, tokens=tokens) 73 | self.assertEqual(zimg.shape[0], batch_size) 74 | self.assertEqual(zimg.shape, ztxt.shape) 75 | else: 76 | variables = model.init(rng, images, train=False) 77 | outputs = model.apply(variables, images, train=False) 78 | if 'unpooled' in name: 79 | self.assertEqual((2, 196, 1000), outputs.shape) 80 | else: 81 | self.assertEqual((2, 1000), outputs.shape) 82 | param_count = sum(p.size for p in jax.tree_flatten(variables)[0]) 83 | self.assertEqual( 84 | size, param_count, 85 | f'Expected {name} to have {size} params, found {param_count}.') 86 | 87 | 88 | if __name__ == '__main__': 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import Any, Callable, Optional, Tuple, Type 16 | 17 | import flax.linen as nn 18 | import jax.numpy as jnp 19 | 20 | from vit_jax import models_resnet 21 | 22 | 23 | Array = Any 24 | PRNGKey = Any 25 | Shape = Tuple[int] 26 | Dtype = Any 27 | 28 | 29 | class IdentityLayer(nn.Module): 30 | """Identity layer, convenient for giving a name to an array.""" 31 | 32 | @nn.compact 33 | def __call__(self, x): 34 | return x 35 | 36 | 37 | class AddPositionEmbs(nn.Module): 38 | """Adds learned positional embeddings to the inputs. 39 | 40 | Attributes: 41 | posemb_init: positional embedding initializer. 42 | """ 43 | 44 | posemb_init: Callable[[PRNGKey, Shape, Dtype], Array] 45 | 46 | @nn.compact 47 | def __call__(self, inputs): 48 | """Applies the AddPositionEmbs module. 49 | 50 | Args: 51 | inputs: Inputs to the layer. 52 | 53 | Returns: 54 | Output tensor with shape `(bs, timesteps, in_dim)`. 55 | """ 56 | # inputs.shape is (batch_size, seq_len, emb_dim). 57 | assert inputs.ndim == 3, ('Number of dimensions should be 3,' 58 | ' but it is: %d' % inputs.ndim) 59 | pos_emb_shape = (1, inputs.shape[1], inputs.shape[2]) 60 | pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape) 61 | return inputs + pe 62 | 63 | 64 | class MlpBlock(nn.Module): 65 | """Transformer MLP / feed-forward block.""" 66 | 67 | mlp_dim: int 68 | dtype: Dtype = jnp.float32 69 | out_dim: Optional[int] = None 70 | dropout_rate: float = 0.1 71 | kernel_init: Callable[[PRNGKey, Shape, Dtype], 72 | Array] = nn.initializers.xavier_uniform() 73 | bias_init: Callable[[PRNGKey, Shape, Dtype], 74 | Array] = nn.initializers.normal(stddev=1e-6) 75 | 76 | @nn.compact 77 | def __call__(self, inputs, *, deterministic, return_acts=False): 78 | """Applies Transformer MlpBlock module.""" 79 | actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim 80 | x = nn.Dense( 81 | features=self.mlp_dim, 82 | dtype=self.dtype, 83 | kernel_init=self.kernel_init, 84 | bias_init=self.bias_init)( # pytype: disable=wrong-arg-types 85 | inputs) 86 | preacts = x 87 | x = nn.gelu(x) 88 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 89 | output = nn.Dense( 90 | features=actual_out_dim, 91 | dtype=self.dtype, 92 | kernel_init=self.kernel_init, 93 | bias_init=self.bias_init)( # pytype: disable=wrong-arg-types 94 | x) 95 | output = nn.Dropout( 96 | rate=self.dropout_rate)( 97 | output, deterministic=deterministic) 98 | if return_acts: 99 | return preacts, output 100 | else: 101 | return output 102 | 103 | 104 | class Encoder1DBlock(nn.Module): 105 | """Transformer encoder layer. 106 | 107 | Attributes: 108 | inputs: input data. 109 | mlp_dim: dimension of the mlp on top of attention block. 110 | dtype: the dtype of the computation (default: float32). 111 | dropout_rate: dropout rate. 112 | attention_dropout_rate: dropout for attention heads. 113 | deterministic: bool, deterministic or not (to apply dropout). 114 | num_heads: Number of heads in nn.MultiHeadDotProductAttention 115 | """ 116 | 117 | mlp_dim: int 118 | num_heads: int 119 | dtype: Dtype = jnp.float32 120 | dropout_rate: float = 0.1 121 | attention_dropout_rate: float = 0.1 122 | 123 | @nn.compact 124 | def __call__(self, inputs, *, deterministic, return_acts=False): 125 | """Applies Encoder1DBlock module. 126 | 127 | Args: 128 | inputs: Inputs to the layer. 129 | deterministic: Dropout will not be applied when set to true. 130 | 131 | Returns: 132 | output after transformer encoder block. 133 | """ 134 | 135 | # Attention block. 
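    # (Pre-norm design: LayerNorm, then multi-head self-attention with dropout,
    # then a residual connection back to `inputs`.)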
136 | assert inputs.ndim == 3, f'Expected (batch, seq, hidden) got {inputs.shape}' 137 | x = nn.LayerNorm(dtype=self.dtype)(inputs) 138 | x = nn.MultiHeadDotProductAttention( 139 | dtype=self.dtype, 140 | kernel_init=nn.initializers.xavier_uniform(), 141 | broadcast_dropout=False, 142 | deterministic=deterministic, 143 | dropout_rate=self.attention_dropout_rate, 144 | num_heads=self.num_heads)( 145 | x, x) 146 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic) 147 | x = x + inputs 148 | 149 | # MLP block. 150 | y = nn.LayerNorm(dtype=self.dtype)(x) 151 | y = MlpBlock( 152 | mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)( 153 | y, deterministic=deterministic, return_acts=return_acts) 154 | if return_acts: 155 | return inputs, y[0], y[1], x + y[1] # i.e., (preattention, preacts, postmlp, output+identity) 156 | else: 157 | return x + y 158 | 159 | 160 | class Encoder(nn.Module): 161 | """Transformer Model Encoder for sequence to sequence translation. 162 | 163 | Attributes: 164 | num_layers: number of layers 165 | mlp_dim: dimension of the mlp on top of attention block 166 | num_heads: Number of heads in nn.MultiHeadDotProductAttention 167 | dropout_rate: dropout rate. 168 | attention_dropout_rate: dropout rate in self attention. 169 | """ 170 | 171 | num_layers: int 172 | mlp_dim: int 173 | num_heads: int 174 | dropout_rate: float = 0.1 175 | attention_dropout_rate: float = 0.1 176 | add_position_embedding: bool = True 177 | 178 | @nn.compact 179 | def __call__(self, x, *, train, return_acts=False, return_layer=-1): 180 | """Applies Transformer model on the inputs. 181 | 182 | Args: 183 | x: Inputs to the layer. 184 | train: Set to `True` when training. 185 | 186 | Returns: 187 | output of a transformer encoder. 188 | """ 189 | assert x.ndim == 3 # (batch, len, emb) 190 | 191 | if self.add_position_embedding: 192 | x = AddPositionEmbs( 193 | posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. 194 | name='posembed_input')( 195 | x) 196 | x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) 197 | 198 | num_layers = return_layer if return_acts and return_layer != -1 else self.num_layers 199 | # Input Encoder 200 | for lyr in range(num_layers): 201 | x = Encoder1DBlock( 202 | mlp_dim=self.mlp_dim, 203 | dropout_rate=self.dropout_rate, 204 | attention_dropout_rate=self.attention_dropout_rate, 205 | name=f'encoderblock_{lyr}', 206 | num_heads=self.num_heads)( 207 | x, deterministic=not train, 208 | return_acts=False if lyr < num_layers-1 else return_acts) 209 | if return_acts: 210 | return x # in this case: x == (pre-attention, pre-GeLU, MLP output (GeLU -> Linear), MLP output + identity) 211 | encoded = nn.LayerNorm(name='encoder_norm')(x) 212 | 213 | return encoded 214 | 215 | 216 | class VisionTransformer(nn.Module): 217 | """VisionTransformer.""" 218 | 219 | num_classes: int 220 | patches: Any 221 | transformer: Any 222 | hidden_size: int 223 | resnet: Optional[Any] = None 224 | representation_size: Optional[int] = None 225 | classifier: str = 'token' 226 | head_bias_init: float = 0. 227 | encoder: Type[nn.Module] = Encoder 228 | model_name: Optional[str] = None 229 | 230 | @nn.compact 231 | def __call__(self, inputs, *, train, return_acts=False, return_layer=-1): 232 | 233 | x = inputs 234 | # (Possibly partial) ResNet root. 235 | if self.resnet is not None: 236 | width = int(64 * self.resnet.width_factor) 237 | 238 | # Root block. 
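      # (Weight-standardized 7x7 stride-2 convolution, then GroupNorm and ReLU;
      # this is the ResNet stem used by the hybrid R+ViT models.)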
239 | x = models_resnet.StdConv( 240 | features=width, 241 | kernel_size=(7, 7), 242 | strides=(2, 2), 243 | use_bias=False, 244 | name='conv_root')( 245 | x) 246 | x = nn.GroupNorm(name='gn_root')(x) 247 | preact = x 248 | x = nn.relu(x) 249 | postact = x 250 | 251 | # print((x > 0).mean(), ((x.reshape(x.shape[0], -1) > 0).sum(0) > 0).mean()) 252 | x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME') 253 | # print((x > 0).mean(), ((x.reshape(x.shape[0], -1) > 0).sum(0) > 0).mean()) 254 | 255 | # ResNet stages (inactive for our model) 256 | if self.resnet.num_layers: 257 | x = models_resnet.ResNetStage( 258 | block_size=self.resnet.num_layers[0], 259 | nout=width, 260 | first_stride=(1, 1), 261 | name='block1')( 262 | x) 263 | for i, block_size in enumerate(self.resnet.num_layers[1:], 1): 264 | x = models_resnet.ResNetStage( 265 | block_size=block_size, 266 | nout=width * 2**i, 267 | first_stride=(2, 2), 268 | name=f'block{i + 1}')( 269 | x) 270 | 271 | n, h, w, c = x.shape 272 | 273 | # import ipdb;ipdb.set_trace() 274 | # We can merge s2d+emb into a single conv; it's the same. 275 | x = nn.Conv( 276 | features=self.hidden_size, 277 | kernel_size=self.patches.size, 278 | strides=self.patches.size, 279 | padding='VALID', 280 | name='embedding')( 281 | x) 282 | postconv2 = x 283 | 284 | if return_layer == 0: 285 | return (inputs, preact, postact, postconv2) 286 | 287 | # Here, x is a grid of embeddings. 288 | 289 | # (Possibly partial) Transformer. 290 | if self.transformer is not None: 291 | n, h, w, c = x.shape 292 | x = jnp.reshape(x, [n, h * w, c]) 293 | 294 | # If we want to add a class token, add it here. 295 | if self.classifier in ['token', 'token_unpooled']: 296 | cls = self.param('cls', nn.initializers.zeros, (1, 1, c)) 297 | cls = jnp.tile(cls, [n, 1, 1]) 298 | x = jnp.concatenate([cls, x], axis=1) 299 | 300 | x = self.encoder(name='Transformer', **self.transformer)(x, train=train, return_acts=return_acts, return_layer=return_layer) 301 | 302 | if self.classifier == 'token': 303 | if len(x) == 4: 304 | return x[0], x[1], x[2], x[3] 305 | else: 306 | x = x[:, 0] 307 | elif self.classifier == 'gap': 308 | if len(x) == 2: 309 | return jnp.mean(x[0], axis=list(range(1, x[0].ndim - 1))), jnp.mean(x[1], axis=list(range(1, x[1].ndim - 1))) 310 | else: 311 | x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2) 312 | elif self.classifier in ['unpooled', 'token_unpooled']: 313 | pass 314 | else: 315 | raise ValueError(f'Invalid classifier={self.classifier}') 316 | 317 | if self.representation_size is not None: 318 | x = nn.Dense(features=self.representation_size, name='pre_logits')(x) 319 | x = nn.tanh(x) 320 | else: 321 | x = IdentityLayer(name='pre_logits')(x) 322 | 323 | if self.num_classes: 324 | x = nn.Dense( 325 | features=self.num_classes, 326 | name='head', 327 | kernel_init=nn.initializers.zeros, 328 | bias_init=nn.initializers.constant(self.head_bias_init))(x) 329 | return x 330 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Preprocessing utilities for text/image models.""" 16 | 17 | import dataclasses 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | import tensorflow_text 22 | 23 | def get_tokenizer(tokenizer_name): 24 | """Returns a tokenizer specified by name ("bert" or "sentencepiece").""" 25 | return { 26 | 'bert': BertTokenizer, 27 | 'sentencepiece': SentencepieceTokenizer, 28 | }[tokenizer_name] 29 | 30 | 31 | @dataclasses.dataclass(frozen=True) 32 | class BertTokenizer: 33 | """BERT tokenizer with prepended CLS token and fixed sequence length. 34 | 35 | This class can be used to tokenize batches of text tokens to numpy arrays 36 | (by calling `__call__()`), or as part of a TensorFlow preprocessing graph 37 | (via the method `preprocess_tf()`). 38 | 39 | Attributes: 40 | vocab_path: Path pointing to the vocabulary file. Can be any path string 41 | that is understood by `tf.io.gfile`. 42 | max_len: Length of tokenized sequences. If the provided texts result in 43 | fewer tokens, then the sequence is zero-padded. If the provided texts 44 | result in more tokens, then the tokens are clipped. 45 | cls_token: Will be set during class construction. 46 | """ 47 | 48 | vocab_path: str 49 | max_len: int 50 | cls_token: int = dataclasses.field(init=False) 51 | 52 | _tokenizer: tensorflow_text.BertTokenizer = dataclasses.field(init=False) 53 | 54 | def __post_init__(self): 55 | tokenizer = tensorflow_text.BertTokenizer( 56 | self.vocab_path, token_out_type=tf.int32, lower_case=True) 57 | with tf.io.gfile.GFile(self.vocab_path) as f: 58 | vocab = f.read().split('\n') 59 | cls_token = vocab.index('[CLS]') 60 | 61 | # Work-around for frozen dataclasses: 62 | # https://stackoverflow.com/questions/53756788 63 | object.__setattr__(self, 'cls_token', cls_token) 64 | object.__setattr__(self, '_tokenizer', tokenizer) 65 | 66 | def preprocess_tf(self, text): 67 | """Tokenizes a single text as part of a TensorFlow graph.""" 68 | return self._preprocess(text[None])[0] 69 | 70 | def _preprocess(self, texts): 71 | token_ids = self._tokenizer.tokenize(texts) 72 | tokens, mask = tensorflow_text.pad_model_inputs(token_ids, self.max_len - 1) 73 | del mask # Recovered from zero padding in model. 74 | count = tf.shape(tokens)[0] 75 | return tf.concat([tf.fill([count, 1], self.cls_token), tokens], axis=1) 76 | 77 | def __call__(self, texts): 78 | """Tokenizes a batch of texts to a numpy array.""" 79 | return self._preprocess(tf.constant(texts)).numpy() 80 | 81 | 82 | @dataclasses.dataclass(frozen=True) 83 | class SentencepieceTokenizer: 84 | """SentencePiece tokenizer with sticky eos. 85 | 86 | Models that use this tokenizer usually use the *last* token, which is 87 | guaranteed to be the "</s>" token (even if tokens are capped to `max_len`). 88 | The same token is used for padding (and exposed as `eos_token`). 89 | 90 | This class can be used to tokenize batches of text tokens to numpy arrays 91 | (by calling `__call__()`), or as part of a TensorFlow preprocessing graph 92 | (via the method `preprocess_tf()`).
93 | 94 | Attributes: 95 | vocab_path: Path pointing to the vocabulary file. Can be any path string 96 | that is understood by `tf.io.gfile`. 97 | max_len: Length of tokenized sequences. If the provided texts result in 98 | fewer tokens, then the sequence is zero-padded. If the provided texts 99 | result in more tokens, then the tokens are clipped. 100 | eos_token: Token used for padding. Last token is guaranteed to be padded. 101 | """ 102 | 103 | vocab_path: str 104 | max_len: int 105 | eos_token: int = dataclasses.field(init=False) 106 | 107 | _tokenizer: tensorflow_text.SentencepieceTokenizer = dataclasses.field(init=False) 108 | 109 | def __post_init__(self): 110 | tokenizer = tensorflow_text.SentencepieceTokenizer( 111 | model=tf.io.gfile.GFile(self.vocab_path, 'rb').read(), add_eos=True) 112 | eos_token = tokenizer.string_to_id('</s>') 113 | 114 | # Work-around for frozen dataclasses: 115 | # https://stackoverflow.com/questions/53756788 116 | object.__setattr__(self, 'eos_token', eos_token) 117 | object.__setattr__(self, '_tokenizer', tokenizer) 118 | 119 | def preprocess_tf(self, text): 120 | """Tokenizes a single text as part of a TensorFlow graph.""" 121 | tokens = self._tokenizer.tokenize(text) 122 | tokens = tokens[:self.max_len - 1] # to guarantee eos at end 123 | return tf.pad( 124 | tokens, [(0, self.max_len - tf.shape(tokens)[0])], 125 | constant_values=self.eos_token) 126 | 127 | def __call__(self, texts): 128 | """Tokenizes a batch of texts to a numpy array.""" 129 | return tf.stack([self.preprocess_tf(text) for text in texts]).numpy() 130 | 131 | 132 | @dataclasses.dataclass(frozen=True) 133 | class PreprocessImages: 134 | """Resizes images and sets value range to [-1, 1]. 135 | 136 | This class can be used to preprocess batches of images to numpy arrays 137 | (by calling `__call__()`), or as part of a TensorFlow preprocessing graph 138 | (via the method `preprocess_tf()`). 139 | 140 | Attributes: 141 | size: Target size of images. 142 | crop: If set to true, then the image will first be resized maintaining the 143 | original aspect ratio, and then a central crop of that resized image will 144 | be returned. 145 | """ 146 | size: int 147 | crop: bool = False 148 | 149 | def _resize_small(self, image): # pylint: disable=missing-docstring 150 | h, w = tf.shape(image)[0], tf.shape(image)[1] 151 | 152 | # Figure out the necessary h/w.
153 | ratio = ( 154 | tf.cast(self.size, tf.float32) / 155 | tf.cast(tf.minimum(h, w), tf.float32)) 156 | h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32) 157 | w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32) 158 | 159 | return tf.image.resize(image, (h, w), method='bilinear') 160 | 161 | def _crop(self, image): 162 | h, w = self.size, self.size 163 | dy = (tf.shape(image)[0] - h) // 2 164 | dx = (tf.shape(image)[1] - w) // 2 165 | return tf.image.crop_to_bounding_box(image, dy, dx, h, w) 166 | 167 | def _resize(self, image): 168 | return tf.image.resize( 169 | image, size=[self.size, self.size], method='bilinear') 170 | 171 | def _value_range(self, image): 172 | image = tf.cast(image, tf.float32) / 255 173 | return -1 + image * 2 174 | 175 | def preprocess_tf(self, image): 176 | """Resizes a single image as part of a TensorFlowg graph.""" 177 | assert image.dtype == tf.uint8 178 | if self.crop: 179 | image = self._resize_small(image) 180 | image = self._crop(image) 181 | else: 182 | image = self._resize(image) 183 | image = tf.cast(image, tf.uint8) 184 | return self._value_range(image) 185 | 186 | def __call__(self, images): 187 | """Resizes a sequence of images, returns a numpy array.""" 188 | return np.stack([ 189 | self.preprocess_tf(tf.constant(image)) for image in images 190 | ]) 191 | 192 | 193 | def get_pp(*, tokenizer_name, vocab_path, max_len, size, crop=False): 194 | """Returns preprocessing function for "image" and "text" features. 195 | 196 | The returned function can directly be used with `tf.data.Dataset.map()`. 197 | If either the text feature (feature key "text") or the image feature (feature 198 | key "image") are not found, then they will be left untouched. 199 | 200 | Note that the "image" feature is overwritten with the resized image, but the 201 | "text" feature is tokenized into a new feature "tokens". 202 | 203 | Args: 204 | tokenizer_name: Name of tokenizer (either "bert", or "sentencepiece"). 205 | vocab_path: Argument passed to tokenizer. 206 | max_len: Argument passed to tokenizer. 207 | size: Argument passed to `PreprocessImages`. 208 | crop: Argument passed to `PreprocessImages`. 209 | """ 210 | tokenizer_class = get_tokenizer(tokenizer_name) 211 | tokenizer = tokenizer_class(vocab_path=vocab_path, max_len=max_len) 212 | preprocess_images = PreprocessImages(size=size, crop=crop) 213 | 214 | def pp(features): 215 | features = {**features} 216 | if 'image' in features: 217 | features['image'] = preprocess_images.preprocess_tf(features['image']) 218 | if 'text' in features: 219 | features['tokens'] = tokenizer.preprocess_tf(features['text']) 220 | return features 221 | 222 | return pp 223 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/preprocess_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import contextlib 16 | import tempfile 17 | from unittest import mock 18 | 19 | from absl.testing import absltest 20 | import numpy as np 21 | import tensorflow as tf 22 | 23 | from vit_jax import preprocess 24 | 25 | VOCAB = """[PAD] 26 | [CLS] 27 | some 28 | test 29 | words""" 30 | 31 | 32 | @contextlib.contextmanager 33 | def _create_vocab(): 34 | with tempfile.NamedTemporaryFile('w') as f: 35 | f.write(VOCAB) 36 | f.flush() 37 | yield f.name 38 | 39 | 40 | class PreprocessTest(absltest.TestCase): 41 | 42 | def test_bert_tokenizer(self): 43 | with _create_vocab() as vocab_path: 44 | tokenizer = preprocess.BertTokenizer(vocab_path=vocab_path, max_len=3) 45 | 46 | tokens = tokenizer(['some', 'test', 'words', 'xxx']) 47 | np.testing.assert_equal(tokens, [ 48 | [1, 2, 0], 49 | [1, 3, 0], 50 | [1, 4, 0], 51 | [1, 5, 0], 52 | ]) 53 | 54 | @mock.patch('tensorflow_text.SentencepieceTokenizer') 55 | @mock.patch('tensorflow.io.gfile.GFile') 56 | def test_sentencepiece_tokenizer(self, gfile_patch, tokenizer_patch): 57 | gfile_patch.return_value.read.return_value = 'test vocab' 58 | eos_token = 7 59 | tokenizer_patch.return_value.string_to_id.return_value = eos_token 60 | tokenizer_patch.return_value.tokenize.side_effect = ( 61 | tf.constant([1, eos_token], tf.int32), 62 | tf.constant([2, 3, eos_token], tf.int32), 63 | tf.constant([4, 5, 6, eos_token], tf.int32), 64 | ) 65 | 66 | tokenizer = preprocess.SentencepieceTokenizer( 67 | vocab_path='test_path', max_len=3) 68 | gfile_patch.assert_called_once_with('test_path', 'rb') 69 | tokenizer_patch.assert_called_once_with(model='test vocab', add_eos=True) 70 | tokenizer_patch.return_value.string_to_id.assert_called_once_with('</s>') 71 | 72 | tokens = tokenizer(['some', 'test', 'words']) 73 | tokenizer_patch.return_value.tokenize.assert_has_calls( 74 | (mock.call('some'), mock.call('test'), mock.call('words'))) 75 | np.testing.assert_equal(tokens, [ 76 | [1, eos_token, eos_token], 77 | [2, 3, eos_token], 78 | [4, 5, eos_token], 79 | ]) 80 | 81 | def test_preprocess_images(self): 82 | # white images with black border 83 | img1 = 255 * np.concatenate([ # portrait image 84 | np.zeros([2, 10, 3], np.uint8), 85 | np.ones([12, 10, 3], np.uint8), 86 | np.zeros([2, 10, 3], np.uint8), 87 | ], axis=0) 88 | img2 = 255 * np.concatenate([ # landscape image 89 | np.zeros([10, 2, 3], np.uint8), 90 | np.ones([10, 12, 3], np.uint8), 91 | np.zeros([10, 2, 3], np.uint8), 92 | ], axis=1) 93 | 94 | preprocess_images = preprocess.PreprocessImages(size=4, crop=False) 95 | imgs = preprocess_images([img1, img2]) 96 | self.assertEqual(imgs.shape, (2, 4, 4, 3)) 97 | self.assertLess(imgs.mean(), 1.0) # borders resized 98 | 99 | preprocess_images = preprocess.PreprocessImages(size=4, crop=True) 100 | imgs = preprocess_images([img1, img2]) 101 | self.assertEqual(imgs.shape, (2, 4, 4, 3)) 102 | self.assertEqual(imgs.mean(), 1.0) # borders cropped 103 | 104 | def test_pp_bert(self): 105 | with _create_vocab() as vocab_path: 106 | pp = preprocess.get_pp( 107 | tokenizer_name='bert', vocab_path=vocab_path, max_len=3, size=4) 108 | 109 | ds = tf.data.Dataset.from_tensor_slices({ 110 | 'text': 111 | tf.constant(['test', 'test']), 112 | 'image': [ 113 | tf.ones([10, 10, 3], tf.uint8), 114 | tf.ones([10, 10, 3], tf.uint8) 115 | ], 116 | }) 117 | 118 | b = next(iter(ds.map(pp).batch(2).as_numpy_iterator())) 119 | dtypes_shapes = {k: (v.dtype, v.shape) for k, v in b.items()}
120 | np.testing.assert_equal(dtypes_shapes, { 121 | 'image': (np.float32, (2, 4, 4, 3)), 122 | 'text': (object, (2,)), 123 | 'tokens': (np.int32, (2, 3)) 124 | }) 125 | 126 | @mock.patch('tensorflow_text.SentencepieceTokenizer') 127 | @mock.patch('tensorflow.io.gfile.GFile') 128 | def test_pp_sentencepiece(self, gfile_patch, tokenizer_patch): 129 | eos_token = 7 130 | gfile_patch.return_value.read.return_value = 'test vocab' 131 | tokenizer_patch.return_value.string_to_id.return_value = eos_token 132 | tokenizer_patch.return_value.tokenize.side_effect = ( 133 | tf.constant([1, eos_token], tf.int32), 134 | tf.constant([2, 3, eos_token], tf.int32), 135 | ) 136 | pp = preprocess.get_pp( 137 | tokenizer_name='sentencepiece', 138 | vocab_path='test', 139 | max_len=3, 140 | size=4) 141 | 142 | ds = tf.data.Dataset.from_tensor_slices({ 143 | 'text': 144 | tf.constant(['test', 'test']), 145 | 'image': [ 146 | tf.ones([10, 10, 3], tf.uint8), 147 | tf.ones([10, 10, 3], tf.uint8) 148 | ], 149 | }) 150 | 151 | b = next(iter(ds.map(pp).batch(2).as_numpy_iterator())) 152 | dtypes_shapes = {k: (v.dtype, v.shape) for k, v in b.items()} 153 | np.testing.assert_equal(dtypes_shapes, { 154 | 'image': (np.float32, (2, 4, 4, 3)), 155 | 'text': (object, (2,)), 156 | 'tokens': (np.int32, (2, 3)) 157 | }) 158 | 159 | if __name__ == '__main__': 160 | absltest.main() 161 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/requirements-tpu.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.12.0 2 | chex>=0.0.7 3 | clu>=0.0.3 4 | einops>=0.3.0 5 | flax>=0.4.1 6 | git+https://github.com/google/flaxformer 7 | 8 | jax[tpu]>=0.2.16 9 | --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html 10 | 11 | ml-collections==0.1.0 12 | numpy>=1.19.5 13 | pandas>=1.1.0 14 | tensorflow-datasets>=4.0.1 15 | tensorflow-probability>=0.11.1 16 | tensorflow-text>=2.9.0 -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py>=0.12.0 2 | chex>=0.0.7 3 | clu>=0.0.3 4 | einops>=0.3.0 5 | flax>=0.6.4 6 | git+https://github.com/google/flaxformer 7 | 8 | jax[cuda11_cudnn82]>=0.4.2 9 | --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 10 | 11 | ml-collections>=0.1.0 12 | numpy>=1.19.5 13 | pandas>=1.1.0 14 | tensorflow-cpu>=2.4.0 # Using tensorflow-cpu to have all GPU memory for JAX. 15 | tensorflow-datasets>=4.0.1 16 | tensorflow-probability>=0.11.1 17 | tensorflow-text>=2.9.0 18 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility code to create fake pre-trained checkpoints.""" 16 | 17 | import os 18 | 19 | import dataclasses 20 | import flax 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | 25 | from vit_jax import models 26 | 27 | 28 | def _traverse_with_names(tree): 29 | """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" 30 | if dataclasses.is_dataclass(tree): 31 | tree = flax.serialization.to_state_dict(tree) 32 | if isinstance(tree, dict) or isinstance(tree, flax.core.FrozenDict): 33 | keys = sorted(tree.keys()) 34 | for key in keys: 35 | for path, v in _traverse_with_names(tree[key]): 36 | yield (key + '/' + path).rstrip('/'), v 37 | else: 38 | yield '', tree 39 | 40 | 41 | def _tree_flatten_with_names(tree): 42 | """Populates tree_flatten with leaf names. 43 | 44 | This function populates output of tree_flatten with leaf names, using a 45 | custom traversal that produces names is provided. The custom traversal does 46 | NOT have to traverse tree in the same order as jax, as we take care of 47 | automatically aligning jax' and custom traversals. 48 | 49 | Args: 50 | tree: python tree. 51 | 52 | Returns: 53 | A list of values with names: [(name, value), ...] 54 | """ 55 | vals, tree_def = jax.tree_flatten(tree) 56 | 57 | # "Fake" token tree that is use to track jax internal tree traversal and 58 | # adjust our custom tree traversal to be compatible with it. 59 | tokens = range(len(vals)) 60 | token_tree = tree_def.unflatten(tokens) 61 | val_names, perm = zip(*_traverse_with_names(token_tree)) 62 | inv_perm = np.argsort(perm) 63 | 64 | # Custom traversal should visit the same number of leaves. 65 | assert len(val_names) == len(vals) 66 | 67 | return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def 68 | 69 | 70 | def _save(data, path): 71 | """Util for checkpointing: saves jax pytree objects to the disk.""" 72 | names_and_vals, _ = _tree_flatten_with_names(data) 73 | os.makedirs(os.path.dirname(path), exist_ok=True) 74 | with open(path, 'wb') as f: 75 | np.savez(f, **{k: v for k, v in names_and_vals}) 76 | 77 | 78 | def create_checkpoint(model_config, path): 79 | """Initializes model and stores weights in specified path.""" 80 | model = models.VisionTransformer(num_classes=1, **model_config) 81 | variables = model.init( 82 | jax.random.PRNGKey(0), 83 | jnp.ones([1, 16, 16, 3], jnp.float32), 84 | train=False, 85 | ) 86 | _save(variables['params'], path) 87 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import functools 16 | import os 17 | import time 18 | 19 | from absl import logging 20 | from clu import metric_writers 21 | from clu import periodic_actions 22 | import flax 23 | from flax.training import checkpoints as flax_checkpoints 24 | import jax 25 | import jax.numpy as jnp 26 | import ml_collections 27 | import numpy as np 28 | import optax 29 | import tensorflow as tf 30 | 31 | from vit_jax import checkpoint 32 | from vit_jax import input_pipeline 33 | from vit_jax import models 34 | from vit_jax import utils 35 | 36 | 37 | def make_update_fn(*, apply_fn, accum_steps, tx): 38 | """Returns update step for data parallel training.""" 39 | 40 | def update_fn(params, opt_state, batch, rng): 41 | 42 | _, new_rng = jax.random.split(rng) 43 | # Bind the rng key to the device id (which is unique across hosts) 44 | # Note: This is only used for multi-host training (i.e. multiple computers 45 | # each with multiple accelerators). 46 | dropout_rng = jax.random.fold_in(rng, jax.lax.axis_index('batch')) 47 | 48 | def cross_entropy_loss(*, logits, labels): 49 | logp = jax.nn.log_softmax(logits) 50 | return -jnp.mean(jnp.sum(logp * labels, axis=1)) 51 | 52 | def loss_fn(params, images, labels): 53 | logits = apply_fn( 54 | dict(params=params), 55 | rngs=dict(dropout=dropout_rng), 56 | inputs=images, 57 | train=True) 58 | return cross_entropy_loss(logits=logits, labels=labels) 59 | 60 | l, g = utils.accumulate_gradient( 61 | jax.value_and_grad(loss_fn), params, batch['image'], batch['label'], 62 | accum_steps) 63 | g = jax.tree_map(lambda x: jax.lax.pmean(x, axis_name='batch'), g) 64 | updates, opt_state = tx.update(g, opt_state) 65 | params = optax.apply_updates(params, updates) 66 | l = jax.lax.pmean(l, axis_name='batch') 67 | 68 | return params, opt_state, l, new_rng 69 | 70 | return jax.pmap(update_fn, axis_name='batch', donate_argnums=(0,)) 71 | 72 | 73 | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str): 74 | """Runs training interleaved with evaluation.""" 75 | 76 | # Setup input pipeline 77 | dataset_info = input_pipeline.get_dataset_info(config.dataset, 'train') 78 | 79 | ds_train, ds_test = input_pipeline.get_datasets(config) 80 | batch = next(iter(ds_train)) 81 | logging.info(ds_train) 82 | logging.info(ds_test) 83 | 84 | # Build VisionTransformer architecture 85 | model_cls = {'ViT': models.VisionTransformer, 86 | 'Mixer': models.MlpMixer}[config.get('model_type', 'ViT')] 87 | model = model_cls(num_classes=dataset_info['num_classes'], **config.model) 88 | 89 | def init_model(): 90 | return model.init( 91 | jax.random.PRNGKey(0), 92 | # Discard the "num_local_devices" dimension for initialization. 93 | jnp.ones(batch['image'].shape[1:], batch['image'].dtype.name), 94 | train=False) 95 | 96 | # Use JIT to make sure params reside in CPU memory. 97 | variables = jax.jit(init_model, backend='cpu')() 98 | 99 | model_or_filename = config.get('model_or_filename') 100 | if model_or_filename: 101 | # Loading model from repo published with "How to train your ViT? Data, 102 | # Augmentation, and Regularization in Vision Transformers" paper. 103 | # https://arxiv.org/abs/2106.10270 104 | if '-' in model_or_filename: 105 | filename = model_or_filename 106 | else: 107 | # Select best checkpoint from i21k pretraining by final upstream 108 | # validation accuracy. 
109 | df = checkpoint.get_augreg_df(directory=config.pretrained_dir) 110 | sel = df.filename.apply( 111 | lambda filename: filename.split('-')[0] == model_or_filename) 112 | best = df.loc[sel].query('ds=="i21k"').sort_values('final_val').iloc[-1] 113 | filename = best.filename 114 | logging.info('Selected fillename="%s" for "%s" with final_val=%.3f', 115 | filename, model_or_filename, best.final_val) 116 | pretrained_path = os.path.join(config.pretrained_dir, 117 | f'{config.model.model_name}.npz') 118 | else: 119 | # ViT / Mixer papers 120 | filename = config.model.model_name 121 | 122 | pretrained_path = os.path.join(config.pretrained_dir, f'{filename}.npz') 123 | if not tf.io.gfile.exists(pretrained_path): 124 | raise ValueError( 125 | f'Could not find "{pretrained_path}" - you can download models from ' 126 | '"gs://vit_models/imagenet21k" or directly set ' 127 | '--config.pretrained_dir="gs://vit_models/imagenet21k".') 128 | params = checkpoint.load_pretrained( 129 | pretrained_path=pretrained_path, 130 | init_params=variables['params'], 131 | model_config=config.model) 132 | 133 | total_steps = config.total_steps 134 | 135 | lr_fn = utils.create_learning_rate_schedule(total_steps, config.base_lr, 136 | config.decay_type, 137 | config.warmup_steps) 138 | tx = optax.chain( 139 | optax.clip_by_global_norm(config.grad_norm_clip), 140 | optax.sgd( 141 | learning_rate=lr_fn, 142 | momentum=0.9, 143 | accumulator_dtype='bfloat16', 144 | ), 145 | ) 146 | 147 | update_fn_repl = make_update_fn( 148 | apply_fn=model.apply, accum_steps=config.accum_steps, tx=tx) 149 | infer_fn_repl = jax.pmap(functools.partial(model.apply, train=False)) 150 | 151 | initial_step = 1 152 | opt_state = tx.init(params) 153 | params, opt_state, initial_step = flax_checkpoints.restore_checkpoint( 154 | workdir, (params, opt_state, initial_step)) 155 | logging.info('Will start/continue training at initial_step=%d', initial_step) 156 | 157 | params_repl, opt_state_repl = flax.jax_utils.replicate((params, opt_state)) 158 | 159 | # Delete references to the objects that are not needed anymore 160 | del opt_state 161 | del params 162 | 163 | # Prepare the learning-rate and pre-fetch it to device to avoid delays. 164 | update_rng_repl = flax.jax_utils.replicate(jax.random.PRNGKey(0)) 165 | 166 | # Setup metric writer & hooks. 
167 | writer = metric_writers.create_default_writer(workdir, asynchronous=False) 168 | writer.write_hparams(config.to_dict()) 169 | hooks = [ 170 | periodic_actions.Profile(logdir=workdir), 171 | periodic_actions.ReportProgress( 172 | num_train_steps=total_steps, writer=writer), 173 | ] 174 | 175 | # Run training loop 176 | logging.info('Starting training loop; initial compile can take a while...') 177 | t0 = lt0 = time.time() 178 | lstep = initial_step 179 | for step, batch in zip( 180 | range(initial_step, total_steps + 1), 181 | input_pipeline.prefetch(ds_train, config.prefetch)): 182 | 183 | with jax.profiler.StepTraceAnnotation('train', step_num=step): 184 | params_repl, opt_state_repl, loss_repl, update_rng_repl = update_fn_repl( 185 | params_repl, opt_state_repl, batch, update_rng_repl) 186 | 187 | for hook in hooks: 188 | hook(step) 189 | 190 | if step == initial_step: 191 | logging.info('First step took %.1f seconds.', time.time() - t0) 192 | t0 = time.time() 193 | lt0, lstep = time.time(), step 194 | 195 | # Report training metrics 196 | if config.progress_every and step % config.progress_every == 0: 197 | img_sec_core_train = (config.batch * (step - lstep) / 198 | (time.time() - lt0)) / jax.device_count() 199 | lt0, lstep = time.time(), step 200 | writer.write_scalars( 201 | step, 202 | dict( 203 | train_loss=float(flax.jax_utils.unreplicate(loss_repl)), 204 | img_sec_core_train=img_sec_core_train)) 205 | done = step / total_steps 206 | logging.info(f'Step: {step}/{total_steps} {100*done:.1f}%, ' # pylint: disable=logging-fstring-interpolation 207 | f'img/sec/core: {img_sec_core_train:.1f}, ' 208 | f'ETA: {(time.time()-t0)/done*(1-done)/3600:.2f}h') 209 | 210 | # Run evaluation 211 | if ((config.eval_every and step % config.eval_every == 0) or 212 | (step == total_steps)): 213 | 214 | accuracies = [] 215 | lt0 = time.time() 216 | for test_batch in input_pipeline.prefetch(ds_test, config.prefetch): 217 | logits = infer_fn_repl( 218 | dict(params=params_repl), test_batch['image']) 219 | accuracies.append( 220 | (np.argmax(logits, 221 | axis=-1) == np.argmax(test_batch['label'], 222 | axis=-1)).mean()) 223 | accuracy_test = np.mean(accuracies) 224 | img_sec_core_test = ( 225 | config.batch_eval * ds_test.cardinality().numpy() / 226 | (time.time() - lt0) / jax.device_count()) 227 | lt0 = time.time() 228 | 229 | lr = float(lr_fn(step)) 230 | logging.info(f'Step: {step} ' # pylint: disable=logging-fstring-interpolation 231 | f'Learning rate: {lr:.7f}, ' 232 | f'Test accuracy: {accuracy_test:0.5f}, ' 233 | f'img/sec/core: {img_sec_core_test:.1f}') 234 | writer.write_scalars( 235 | step, 236 | dict( 237 | accuracy_test=accuracy_test, 238 | lr=lr, 239 | img_sec_core_test=img_sec_core_test)) 240 | 241 | # Store checkpoint. 242 | if ((config.checkpoint_every and step % config.eval_every == 0) or 243 | step == total_steps): 244 | checkpoint_path = flax_checkpoints.save_checkpoint( 245 | workdir, (flax.jax_utils.unreplicate(params_repl), 246 | flax.jax_utils.unreplicate(opt_state_repl), step), step) 247 | logging.info('Stored checkpoint at step %d to "%s"', step, 248 | checkpoint_path) 249 | 250 | return flax.jax_utils.unreplicate(params_repl) 251 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/train_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import tempfile 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import ml_collections 21 | import tensorflow_datasets as tfds 22 | 23 | from vit_jax import test_utils 24 | from vit_jax import train 25 | from vit_jax.configs import common 26 | from vit_jax.configs import models 27 | 28 | # from PIL import Image 29 | # import numpy as np 30 | # Image.fromarray(np.array([[[0, 0, 0]]], np.uint8)).save('black1px.jpg') 31 | # print(repr(file('black1px.jpg', 'rb').read())) 32 | JPG_BLACK_1PX = (b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c' 33 | b' $.\' ' 34 | b'",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\t\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xc0\x00\x11\x08\x00\x01\x00\x01\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1f\x00\x00\x01\x05\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x10\x00\x02\x01\x03\x03\x02\x04\x03\x05\x05\x04\x04\x00\x00\x01}\x01\x02\x03\x00\x04\x11\x05\x12!1A\x06\x13Qa\x07"q\x142\x81\x91\xa1\x08#B\xb1\xc1\x15R\xd1\xf0$3br\x82\t\n\x16\x17\x18\x19\x1a%&\'()*456789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xc4\x00\x1f\x01\x00\x03\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\xff\xc4\x00\xb5\x11\x00\x02\x01\x02\x04\x04\x03\x04\x07\x05\x04\x04\x00\x01\x02w\x00\x01\x02\x03\x11\x04\x05!1\x06\x12AQ\x07aq\x13"2\x81\x08\x14B\x91\xa1\xb1\xc1\t#3R\xf0\x15br\xd1\n\x16$4\xe1%\xf1\x17\x18\x19\x1a&\'()*56789:CDEFGHIJSTUVWXYZcdefghijstuvwxyz\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x92\x93\x94\x95\x96\x97\x98\x99\x9a\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xff\xda\x00\x0c\x03\x01\x00\x02\x11\x03\x11\x00?\x00\xf9\xfe\x8a(\xa0\x0f\xff\xd9') # pylint: disable=line-too-long 35 | 36 | 37 | class TrainTest(parameterized.TestCase): 38 | 39 | @parameterized.named_parameters( 40 | ('tfds', 'tfds'), 41 | ('directory', 'directory'), 42 | ) 43 | def test_train_and_evaluate(self, dataset_source): 44 | config = common.get_config() 45 | config.model = models.get_testing_config() 46 | config.batch = 64 47 | config.accum_steps = 2 48 | config.batch_eval = 8 49 | config.total_steps = 1 50 | 51 | with tempfile.TemporaryDirectory() as workdir: 52 | if 
dataset_source == 'tfds': 53 | config.dataset = 'cifar10' 54 | config.pp = ml_collections.ConfigDict({ 55 | 'train': 'train[:98%]', 56 | 'test': 'test', 57 | 'crop': 224 58 | }) 59 | elif dataset_source == 'directory': 60 | config.dataset = os.path.join(workdir, 'dataset') 61 | config.pp = ml_collections.ConfigDict({'crop': 224}) 62 | for mode in ('train', 'test'): 63 | for class_name in ('test1', 'test2'): 64 | for i in range(8): 65 | path = os.path.join(config.dataset, mode, class_name, f'{i}.jpg') 66 | os.makedirs(os.path.dirname(path), exist_ok=True) 67 | with open(path, 'wb') as f: 68 | f.write(JPG_BLACK_1PX) 69 | else: 70 | raise ValueError(f'Unknown dataset_source: "{dataset_source}"') 71 | 72 | config.pretrained_dir = workdir 73 | test_utils.create_checkpoint(config.model, f'{workdir}/testing.npz') 74 | 75 | _ = train.train_and_evaluate(config, workdir) 76 | self.assertTrue(os.path.exists(f'{workdir}/checkpoint_1')) 77 | 78 | 79 | if __name__ == '__main__': 80 | absltest.main() 81 | -------------------------------------------------------------------------------- /contrastive_text_image_learning/vision_transformer/vit_jax/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging as python_logging 16 | import os 17 | import threading 18 | 19 | from absl import logging 20 | import jax 21 | import jax.numpy as jnp 22 | import tensorflow as tf 23 | 24 | 25 | class GFileHandler(python_logging.StreamHandler): 26 | """Writes log messages to file using tf.io.gfile.""" 27 | 28 | def __init__(self, filename, mode, flush_secs=1.0): 29 | super().__init__() 30 | tf.io.gfile.makedirs(os.path.dirname(filename)) 31 | if mode == 'a' and not tf.io.gfile.exists(filename): 32 | mode = 'w' 33 | self.filehandle = tf.io.gfile.GFile(filename, mode) 34 | self.flush_secs = flush_secs 35 | self.flush_timer = None 36 | 37 | def flush(self): 38 | self.filehandle.flush() 39 | 40 | def emit(self, record): 41 | msg = self.format(record) 42 | self.filehandle.write(f'{msg}\n') 43 | if self.flush_timer is not None: 44 | self.flush_timer.cancel() 45 | self.flush_timer = threading.Timer(self.flush_secs, self.flush) 46 | self.flush_timer.start() 47 | 48 | 49 | def add_gfile_logger(workdir, *, basename='train', level=python_logging.INFO): 50 | """Adds GFile file logger to Python logging handlers.""" 51 | fh = GFileHandler(f'{workdir}/{basename}.log', 'a') 52 | fh.setLevel(level) 53 | fh.setFormatter(logging.PythonFormatter()) 54 | python_logging.getLogger('').addHandler(fh) 55 | 56 | 57 | def create_learning_rate_schedule(total_steps, 58 | base, 59 | decay_type, 60 | warmup_steps, 61 | linear_end=1e-5): 62 | """Creates learning rate schedule. 63 | 64 | Currently only warmup + {linear,cosine} but will be a proper mini-language 65 | like preprocessing one in the future. 66 | 67 | Args: 68 | total_steps: The total number of steps to run. 
69 | base: The starting learning-rate (without warmup). 70 | decay_type: 'linear' or 'cosine'. 71 | warmup_steps: how many steps to warm up for. 72 | linear_end: Minimum learning rate. 73 | 74 | Returns: 75 | A function learning_rate(step): float -> {"learning_rate": float}. 76 | """ 77 | 78 | def step_fn(step): 79 | """Step to learning rate function.""" 80 | lr = base 81 | 82 | progress = (step - warmup_steps) / float(total_steps - warmup_steps) 83 | progress = jnp.clip(progress, 0.0, 1.0) 84 | if decay_type == 'linear': 85 | lr = linear_end + (lr - linear_end) * (1.0 - progress) 86 | elif decay_type == 'cosine': 87 | lr = lr * 0.5 * (1. + jnp.cos(jnp.pi * progress)) 88 | else: 89 | raise ValueError(f'Unknown lr type {decay_type}') 90 | 91 | if warmup_steps: 92 | lr = lr * jnp.minimum(1., step / warmup_steps) 93 | 94 | return jnp.asarray(lr, dtype=jnp.float32) 95 | 96 | return step_fn 97 | 98 | 99 | def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps): 100 | """Accumulate gradient over multiple steps to save on memory.""" 101 | if accum_steps and accum_steps > 1: 102 | assert images.shape[0] % accum_steps == 0, ( 103 | f'Bad accum_steps {accum_steps} for batch size {images.shape[0]}') 104 | step_size = images.shape[0] // accum_steps 105 | l, g = loss_and_grad_fn(params, images[:step_size], labels[:step_size]) 106 | 107 | def acc_grad_and_loss(i, l_and_g): 108 | imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0), 109 | (step_size,) + images.shape[1:]) 110 | lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0), 111 | (step_size, labels.shape[1])) 112 | li, gi = loss_and_grad_fn(params, imgs, lbls) 113 | l, g = l_and_g 114 | return (l + li, jax.tree_map(lambda x, y: x + y, g, gi)) 115 | 116 | l, g = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss, (l, g)) 117 | return jax.tree_map(lambda x: x / accum_steps, (l, g)) 118 | else: 119 | return loss_and_grad_fn(params, images, labels) 120 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.17.2 2 | torch==1.2.0 3 | torchvision==0.4.0 4 | -------------------------------------------------------------------------------- /sam_low_rank_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tml-epfl/sam-low-rank-features/0e92a35b7bba64adbae76e56694282fe047d71bb/sam_low_rank_summary.png -------------------------------------------------------------------------------- /two_layer_nets/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from fc_nets import FCNet2Layers, FCNet, compute_grad_matrix 4 | 5 | 6 | def get_iters_eval(n_iter_power, x_log_scale, n_iters_first=100, n_iters_next=151): 7 | num_iter = int(10**n_iter_power) + 1 8 | 9 | iters_loss_first = np.array(range(n_iters_first)) 10 | if x_log_scale: 11 | iters_loss_next = np.unique(np.round(np.logspace(0, n_iter_power, n_iters_next))) 12 | else: 13 | iters_loss_next = np.unique(np.round(np.linspace(0, num_iter, n_iters_next)))[:-1] 14 | iters_loss = np.unique(np.concatenate((iters_loss_first, iters_loss_next))) 15 | 16 | return num_iter, iters_loss 17 | 18 | 19 | def get_data_two_layer_relu_net(n, d, m_teacher, init_scales_teacher, seed, clsf=False, clsf_mu=0.0, clsf_margin=0.0001, biases=False, act='relu'): 20 | np.random.seed(seed) 21 | 
torch.manual_seed(seed) 22 | 23 | n_test = 1000 24 | H = np.eye(d) 25 | X = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n)).float() 26 | if not clsf: 27 | X = X / torch.sum(X**2, 1, keepdim=True)**0.5 28 | X_test = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n_test)).float() 29 | if not clsf: 30 | X_test = X_test / torch.sum(X_test**2, 1, keepdim=True)**0.5 31 | 32 | # generate ground truth labels 33 | with torch.no_grad(): 34 | net_teacher = FCNet2Layers(n_feature=d, n_hidden=m_teacher, biases=[True, True] if biases else [False, False], act=act) 35 | net_teacher.init_gaussian(init_scales_teacher) 36 | net_teacher.layer1.weight.data = net_teacher.layer1.weight.data / torch.sum((net_teacher.layer1.weight.data)**2, 1, keepdim=True)**0.5 37 | net_teacher.layer2.weight.data = torch.sign(net_teacher.layer2.weight.data) 38 | if clsf: 39 | X[:n//2] -= clsf_mu 40 | X[n//2:] += clsf_mu 41 | X_test[:n//2] -= clsf_mu 42 | X_test[n//2:] += clsf_mu 43 | 44 | y, y_test = net_teacher(X), net_teacher(X_test) 45 | 46 | if clsf: # convert to -1 / 1 47 | # idx_train, idx_test = torch.abs(y.flatten()) > clsf_margin, torch.abs(y_test.flatten()) > clsf_margin 48 | # X, y, X_test, y_test = X[idx_train], y[idx_train], X_test[idx_test], y_test[idx_test] 49 | 50 | y[y < -clsf_margin], y[torch.abs(y) <= clsf_margin], y[y > clsf_margin] = -1, ((torch.randn((y[torch.abs(y) <= clsf_margin]).shape) > 0).float() - 0.5) * 2, 1 51 | y_test[y_test < -clsf_margin], y_test[torch.abs(y_test) <= clsf_margin], y_test[y_test > clsf_margin] = -1, ((torch.randn((y_test[torch.abs(y_test) <= clsf_margin]).shape) > 0).float() - 0.5) * 2, 1 52 | 53 | # y, y_test = ((y > 0).float() - 0.5) * 2, ((y_test > 0).float() - 0.5) * 2 54 | print('y', y[:20, 0]) 55 | 56 | return X, y, X_test, y_test, net_teacher 57 | 58 | 59 | def get_data_multi_layer_relu_net(n, d, m_teacher, init_scales_teacher, seed): 60 | np.random.seed(seed + 1) 61 | torch.manual_seed(seed + 1) 62 | 63 | n_test = 1000 64 | H = np.eye(d) 65 | X = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n)).float() 66 | X = X / torch.sum(X**2, 1, keepdim=True)**0.5 67 | X_test = torch.tensor(np.random.multivariate_normal(np.zeros(d), H, n_test)).float() 68 | X_test = X_test / torch.sum(X_test**2, 1, keepdim=True)**0.5 69 | 70 | # generate ground truth labels 71 | with torch.no_grad(): 72 | net_teacher = FCNet(n_feature=d, n_hidden=m_teacher) 73 | net_teacher.init_gaussian(init_scales_teacher) 74 | y, y_test = net_teacher(X), net_teacher(X_test) 75 | print('y:', y[:, 0]) 76 | 77 | return X, y, X_test, y_test, net_teacher 78 | 79 | 80 | def effective_rank(v): 81 | v = v[v != 0] 82 | v /= v.sum() 83 | return -(v * np.log(v)).sum() 84 | 85 | 86 | def rm_too_correlated(net, X, V, corr_threshold=0.99): 87 | V = V.T 88 | idx_keep = np.where((V > 0.0).sum(0) > 0)[0] 89 | V_filtered = V[:, idx_keep] # filter out zeros 90 | corr_matrix = np.corrcoef(V_filtered.T) 91 | corr_matrix -= np.eye(corr_matrix.shape[0]) 92 | 93 | idx_to_delete, i, j = [], 0, 0 94 | while i != corr_matrix.shape[0]: 95 | if (np.abs(corr_matrix[i]) > corr_threshold).sum() > 0: 96 | corr_matrix = np.delete(corr_matrix, (i), axis=0) 97 | corr_matrix = np.delete(corr_matrix, (i), axis=1) 98 | # print('delete', j) 99 | idx_to_delete.append(j) 100 | else: 101 | i += 1 102 | j += 1 103 | assert corr_matrix.shape[0] == corr_matrix.shape[1] 104 | idx_keep = np.delete(idx_keep, [idx_to_delete]) 105 | 106 | return V[:, idx_keep].T 107 | 108 | def compute_grad_matrix_dim(net, X, 
corr_threshold=0.99): 109 | grad_matrix = compute_grad_matrix(net, X) 110 | grad_matrix_sq_norms = np.sum(grad_matrix**2, 0) 111 | m = 100 112 | v_j = [] 113 | for j in range(m): 114 | v_j.append(grad_matrix_sq_norms[[j, m+j, 2*m+j]]) # matrix: w1, w2, w3, w4 115 | V = np.vstack(v_j) 116 | 117 | V_reduced = rm_too_correlated(net, X, V, corr_threshold=corr_threshold) 118 | grad_matrix_dim = V_reduced.shape[0] 119 | return grad_matrix_dim 120 | 121 | def compute_hessian(net, X, y): 122 | def loss_function(*all_params): 123 | w, bw, v, bv = all_params 124 | loss_f = lambda y_pred, y: torch.mean((y_pred - y)**2) 125 | y_pred = torch.relu(X @ w.T + bw) @ v.T + bv 126 | loss = loss_f(y_pred, y) 127 | return loss 128 | 129 | h = torch.autograd.functional.hessian(loss_function, tuple(p for p in net.parameters())) 130 | # TODO: unfinished; the Hessian function returns a list of matrices, but we need to compose a single matrix out of them 131 | 132 | --------------------------------------------------------------------------------
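
The TODO in compute_hessian concerns the output format of torch.autograd.functional.hessian: given a tuple of k parameter tensors, it returns a k x k nested structure of blocks, where block (i, j) has shape params[i].shape + params[j].shape. Below is a minimal sketch of how these blocks could be assembled into a single (n_params, n_params) matrix; the helper name assemble_full_hessian is illustrative and not part of the repository.

import torch

def assemble_full_hessian(h, params):
    # h[i][j] has shape params[i].shape + params[j].shape; flatten each block
    # to a (numel_i, numel_j) matrix and concatenate the blocks.
    sizes = [p.numel() for p in params]
    rows = []
    for i, n_i in enumerate(sizes):
        rows.append(torch.cat(
            [h[i][j].reshape(n_i, n_j) for j, n_j in enumerate(sizes)], dim=1))
    return torch.cat(rows, dim=0)  # shape: (sum(sizes), sum(sizes))

Because each block is flattened in row-major order, the assembled matrix equals the Hessian of the loss with respect to torch.cat([p.flatten() for p in params]); with the parameter tuple used in compute_hessian it would be called as assemble_full_hessian(h, tuple(net.parameters())).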