├── imgs
│   └── ConvMixer-ViT.png
├── requirements.txt
├── convmixer.sh
├── vit_pex.sh
├── LICENSE
├── util.py
├── README.md
├── mytrain_convmixer.py
├── mytrain_vit.py
├── convmixer.py
└── vit.py
/imgs/ConvMixer-ViT.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/osiriszjq/impulse_init/HEAD/imgs/ConvMixer-ViT.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | matplotlib
3 | torch==1.13.1
4 | torchvision==0.14.1
5 | einops
6 | wandb
--------------------------------------------------------------------------------
/convmixer.sh:
--------------------------------------------------------------------------------
1 | data_path='./data'
2 | for lr in '1e-3'; do
3 | for dataset in 'cifar10' 'cifar100'; do
4 | for heads in '512'; do
5 | for init in 'random' 'softmax' 'box1' 'box25'; do
6 | 
7 | # for i in 1 2 3 4 5; do
8 | python mytrain_convmixer.py --dataset ${dataset} --data_path ${data_path} --init ${init} --heads ${heads} --lr ${lr}
9 | python mytrain_convmixer.py --dataset ${dataset} --data_path ${data_path} --init ${init} --heads ${heads} --lr ${lr} --fix_spatial
10 | # done
11 | 
12 | done
13 | done
14 | done
15 | done
--------------------------------------------------------------------------------
/vit_pex.sh:
--------------------------------------------------------------------------------
1 | data_path='./data'
2 | for dataset in 'cifar10' 'cifar100' 'svhn' 'tiny_imagenet'; do
3 | for lr in '1e-4'; do
4 | for init in 'impulse16_64_5_0.1_100' 'mimetic512_64' 'random512_64'; do # 'impulse16_64_5_0.1_100' 'mimetic512_64' 'random512_64'
5 | # for i in 1 2 3 4 5; do
6 | for alpha in '0.0' '0.1' '0.2' '0.3' '0.4' '0.5'; do
7 | python mytrain_vit.py --dataset ${dataset} --data_path ${data_path} --lr ${lr} --spatial_pe --spatial_x --init ${init} --use_value --trainable --alpha ${alpha} --data_aug
8 | python mytrain_vit.py --dataset ${dataset} --data_path ${data_path} --lr ${lr} --spatial_pe --spatial_x --init ${init} --use_value --trainable --alpha ${alpha}
9 | done
10 | # done
11 | 
12 | done
13 | done
14 | done
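# Note on the init strings above (roughly, as parsed in vit.py -> Attention):
#   impulse{grid}_{rank}_{ff}_{scale}_{norm}: fit rank-{rank} Q/K factors so that softmax attention
#     over the {grid}x{grid} patch grid approximates a random single-impulse (convolution-like)
#     pattern within a {ff}x{ff} window; {scale} is the softmax temperature and {norm} the
#     re-normalisation interval of the fitting loop.
#   mimetic{dim}_{rank} / random{dim}_{rank}: rank-{rank} mimetic-style / random Q,K baselines.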
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Jianqiao Zheng
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | 
4 | # my srank
5 | def srank_l12(X):
6 |     (u,s,v) = torch.svd(X)
7 |     sr2 = (s*s).sum()/s[0]/s[0]
8 |     sr1 = s.sum()/s[0]
9 |     return sr1,sr2
10 | 
11 | 
12 | # my counting parameters
13 | def count_parameters(model):
14 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
15 | 
16 | 
17 | # for tiny imagenet 200
18 | def create_val_img_folder(args):
19 |     '''
20 |     This method is responsible for separating validation images into separate sub folders
21 |     '''
22 |     dataset_dir = os.path.join(args.data_path, 'tiny-imagenet-200')
23 |     val_dir = os.path.join(dataset_dir, 'val')
24 |     img_dir = os.path.join(val_dir, 'images')
25 | 
26 |     fp = open(os.path.join(val_dir, 'val_annotations.txt'), 'r')
27 |     data = fp.readlines()
28 |     val_img_dict = {}
29 |     for line in data:
30 |         words = line.split('\t')
31 |         val_img_dict[words[0]] = words[1]
32 |     fp.close()
33 | 
34 |     # Create folder if not present and move images into proper folders
35 |     for img, folder in val_img_dict.items():
36 |         newpath = (os.path.join(img_dir, folder))
37 |         if not os.path.exists(newpath):
38 |             os.makedirs(newpath)
39 |         if os.path.exists(os.path.join(img_dir, img)):
40 |             os.rename(os.path.join(img_dir, img), os.path.join(newpath, img))
41 | 
42 | 
43 | def get_class_name(args):
44 |     class_to_name = dict()
45 |     fp = open(os.path.join(args.data_dir, args.dataset, 'words.txt'), 'r')
46 |     data = fp.readlines()
47 |     for line in data:
48 |         words = line.strip('\n').split('\t')
49 |         class_to_name[words[0]] = words[1].split(',')[0]
50 |     fp.close()
51 |     return class_to_name
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Convolutional Initialization for Data-Efficient Vision Transformers
2 | ### [Project Page](https://osiriszjq.github.io/impulse_init) | [Paper](https://arxiv.org/pdf/2401.12511.pdf)
3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
4 | 
5 | 
6 | [Jianqiao Zheng](https://github.com/osiriszjq/),
7 | [Xueqian Li](https://lilac-lee.github.io/),
8 | [Simon Lucey](https://www.adelaide.edu.au/directory/simon.lucey)<br>
9 | The University of Adelaide
10 | 
11 | ---
12 | 🚀 **New! Explore our updated NeurIPS'25 work [Structured Initialization for Vision Transformers](https://github.com/osiriszjq/structured_initialization) — Released in Dec 2025**
13 | 
14 | ---
15 | 
16 | This is the official implementation of the paper "Convolutional Initialization for Data-Efficient Vision Transformers", including modified versions of [ConvMixer](https://arxiv.org/abs/2201.09792) and [Simple ViT](https://arxiv.org/abs/2205.01580) trained on CIFAR-10, CIFAR-100, SVHN, and [Tiny ImageNet](http://vision.stanford.edu/teaching/cs231n/reports/2015/pdfs/yle_project.pdf). The code is based on [vision-transformers-cifar10](https://github.com/kentaroy47/vision-transformers-cifar10/tree/main).
17 | 
18 | #### Illustration of different methods to extend 1D encoding
19 | ![Illustration of different methods to extend 1D encoding](imgs/ConvMixer-ViT.png)
20 | 
21 | 
22 | ## Google Colab
23 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/osiriszjq/impulse_init/blob/main/Impulse_Initialization.ipynb)<br>
24 | If you want to try out our new initialization for ViT, check this [Colab](https://github.com/osiriszjq/impulse_init/blob/main/Impulse_Initialization.ipynb) for a quick tour.
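If you prefer to build the model directly in Python, here is a minimal sketch that mirrors the defaults of `mytrain_vit.py` together with the flags used in `vit_pex.sh` (note that the `impulse...` init runs a short GPU fitting loop inside `vit.py`, so it needs a CUDA device):
```
from vit import SimpleViT

# CIFAR-10 geometry: 32x32 images with patch size 2 give a 16x16 patch grid,
# which is what the leading "16" in the init string refers to.
model = SimpleViT(
    image_size=32, patch_size=2, num_classes=10,
    dim=512, depth=6, heads=8, mlp_dim=512, dim_head=64,
    input_pe=False, spatial_pe=True, spatial_x=True, use_value=True,
    init='impulse16_64_5_0.1_100', alpha=0.5, trainable=True,
)
```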
25 | 
26 | 
27 | ## Usage
28 | First edit `convmixer.sh` or `vit_pex.sh` to set the data path and choose which experiments to run, then just run
29 | ```
30 | bash convmixer.sh
31 | ```
32 | or
33 | ```
34 | bash vit_pex.sh
35 | ```
36 | 
37 | 
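To run a single configuration rather than the full sweep, the training scripts can also be called directly. The flags below mirror one iteration of `convmixer.sh` and `vit_pex.sh`; add `--nowandb` to run without Weights & Biases logging.
```
python mytrain_convmixer.py --dataset cifar10 --data_path ./data --init box1 --heads 512 --lr 1e-3
python mytrain_vit.py --dataset cifar10 --data_path ./data --lr 1e-4 --spatial_pe --spatial_x --init impulse16_64_5_0.1_100 --use_value --trainable --alpha 0.1 --data_aug
```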
38 | ## Citation
39 | ```
40 | @article{zheng2024convolutional,
41 |   title={Convolutional Initialization for Data-Efficient Vision Transformers},
42 |   author={Zheng, Jianqiao and Li, Xueqian and Lucey, Simon},
43 |   journal={arXiv preprint arXiv:2401.12511},
44 |   year={2024}
45 | }
46 | ```
47 | 
--------------------------------------------------------------------------------
/mytrain_convmixer.py:
--------------------------------------------------------------------------------
1 | # https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/train_cifar10.py
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | 
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | 
11 | import time
12 | import argparse
13 | from util import *
14 | from convmixer import *
15 | 
16 | 
17 | parser = argparse.ArgumentParser()
18 | 
19 | parser.add_argument("--dataset", type=str, default="cifar10")
20 | parser.add_argument("--data_path", type=str, default="./data")
21 | parser.add_argument('--nowandb', action='store_true', help='disable wandb')
22 | 
23 | parser.add_argument('--opt', default="adam")
24 | parser.add_argument('--scheduler', default="cos")
25 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
26 | parser.add_argument('--wd', default=0.0, type=float)
27 | parser.add_argument('--epochs', default=200, type=int)
28 | parser.add_argument('--batch-size', default=512, type=int)
29 | parser.add_argument('--workers', default=16, type=int)
30 | parser.add_argument('--data_aug', action='store_true')
31 | parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. for older pytorch versions')
32 | 
33 | parser.add_argument('--dim', default=512, type=int)
34 | parser.add_argument('--heads', default=0, type=int, help='number of different filters in one layer')
35 | parser.add_argument('--depth', default=6, type=int)
36 | parser.add_argument('--psize', default=2, type=int)
37 | parser.add_argument('--conv-ks', default=5, type=int)
38 | 
39 | parser.add_argument('--fix_spatial', action='store_true', help='freeze spatial mixing')
40 | parser.add_argument("--init", type=str, default="random")
41 | parser.add_argument('--input_weight', action='store_true', help='share weights in different layers')
42 | parser.add_argument("--linear_format", action='store_true', help='use linear format conv filters')
43 | parser.add_argument("--no_spatial_bias", action='store_true', help='disable bias for spatial conv')
44 | 
45 | 
46 | args = parser.parse_args()
47 | args.spatial = not args.fix_spatial
48 | args.spatial_bias = not args.no_spatial_bias
49 | use_amp = not args.noamp
50 | print(args)
51 | 
52 | 
53 | usewandb = not args.nowandb
54 | if usewandb:
55 |     import wandb
56 |     watermark = "{}_h{}".format(args.init,args.heads)
57 |     wandb.init(project=f'convmixer-{args.dataset}',name=watermark)
58 |     wandb.config.update(args)
59 | 
60 | 
61 | 
62 | print(f'==> Preparing {args.dataset} data..')
63 | dataset_mean = (0.4914, 0.4822, 0.4465)
64 | dataset_std = (0.2471, 0.2435, 0.2616)
65 | if args.data_aug:
66 |     train_transform = transforms.Compose([
67 |         transforms.RandAugment(2, 14),
68 |         transforms.RandomResizedCrop(32, scale=(1.0,1.0), ratio=(1.0,1.0)),  # fix: was RandomCrop, which does not accept scale/ratio arguments
69 |         transforms.RandomHorizontalFlip(),
70 |         transforms.ToTensor(),
71 |         transforms.Normalize(dataset_mean, dataset_std)
72 |     ])
73 | else:
74 |     train_transform = transforms.Compose([
75 |         transforms.ToTensor(),
76 |         transforms.Normalize(dataset_mean, dataset_std)
77 |     ])
78 | 
79 | test_transform = transforms.Compose([
80 |     transforms.ToTensor(),
81 |     transforms.Normalize(dataset_mean, dataset_std)
82 | ])
83 | if args.dataset == 'cifar10':
84 |     n_class = 10
85 |     image_size = 32
86 |     trainset = torchvision.datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=train_transform)
87 |     testset = torchvision.datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=test_transform)
88 | elif args.dataset == 'cifar100':
89 |     n_class = 100
90 |     image_size = 32
91 |     trainset = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=train_transform)
92 |     testset = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=test_transform)
93 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
94 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
95 | 
96 | 
97 | 
98 | print('==> Building model..')
99 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
100 | model = ConvMixer(args.dim, args.depth, patch_size=args.psize, kernel_size=args.conv_ks, n_classes=n_class, image_size=image_size, return_embedding=False,
101 |                   init=args.init, heads=args.heads, spatial=args.spatial, spatial_bias=args.spatial_bias, input_weight=args.input_weight,linear_format=args.linear_format)
102 | if 'cuda' in device:
103 |     print(device)
104 |     print("using data parallel")
105 |     model = torch.nn.DataParallel(model).cuda()
106 |     cudnn.benchmark = True
107 | criterion = nn.CrossEntropyLoss()
108 | 
109 | if args.opt == "adam":
110 |
optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) 111 | elif args.opt == "adamw": 112 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd) 113 | 114 | if args.scheduler == "cos": 115 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) 116 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp) 117 | 118 | num_param = count_parameters(model) 119 | for epoch in range(args.epochs): 120 | start = time.time() 121 | train_loss, train_acc, n = 0, 0, 0 122 | n_batch = 0 123 | for i, (X, y) in enumerate(trainloader): 124 | model.train() 125 | X, y = X.cuda(), y.cuda() 126 | 127 | with torch.cuda.amp.autocast(enabled=use_amp): 128 | output = model(X) 129 | loss = criterion(output, y) 130 | scaler.scale(loss).backward() 131 | scaler.step(optimizer) 132 | scaler.update() 133 | optimizer.zero_grad() 134 | 135 | train_loss += loss.item() * y.size(0) 136 | train_acc += (output.max(1)[1] == y).sum().item() 137 | n += y.size(0) 138 | train_acc = train_acc/n 139 | train_loss = train_loss/n 140 | 141 | 142 | model.eval() 143 | test_acc, m = 0, 0 144 | with torch.no_grad(): 145 | for i, (X, y) in enumerate(testloader): 146 | X, y = X.cuda(), y.cuda() 147 | with torch.cuda.amp.autocast(): 148 | output = model(X) 149 | test_acc += (output.max(1)[1] == y).sum().item() 150 | m += y.size(0) 151 | test_acc = test_acc/m 152 | scheduler.step() 153 | 154 | if usewandb: 155 | wandb.log({'epoch': epoch, 'train_loss': train_loss, 'train_acc': train_acc, "val_acc": test_acc, "lr": optimizer.param_groups[0]["lr"], 156 | "epoch_time": time.time()-start, 'num_param':num_param}) 157 | else: 158 | print(f'epoch: {epoch}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, val_acc: {test_acc:.4f}, lr: {optimizer.param_groups[0]["lr"]:.6f}, epoch_time: {time.time()-start:.1f}, num_param:{num_param}') -------------------------------------------------------------------------------- /mytrain_vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/train_cifar10.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.backends.cudnn as cudnn 7 | 8 | import torchvision 9 | import torchvision.transforms as transforms 10 | 11 | import time 12 | import argparse 13 | from util import * 14 | from vit import * 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument("--dataset", type=str, default="cifar10") 20 | parser.add_argument("--data_path", type=str, default="./data") 21 | parser.add_argument('--nowandb', action='store_true', help='disable wandb') 22 | 23 | parser.add_argument('--opt', default="adam") 24 | parser.add_argument('--scheduler', default="cos") 25 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate') 26 | parser.add_argument('--wd', default=0.0, type=float) 27 | parser.add_argument('--epochs', default=200, type=int) 28 | parser.add_argument('--batch-size', default=512, type=int) 29 | parser.add_argument('--num_workers', default=16, type=int) 30 | parser.add_argument('--data_aug', action='store_true') 31 | parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. 
for older pytorch versions') 32 | 33 | parser.add_argument('--dim', default=512, type=int) 34 | parser.add_argument('--heads', default=8, type=int) 35 | parser.add_argument('--depth', default=6, type=int) 36 | parser.add_argument('--psize', default=2, type=int) 37 | parser.add_argument('--mlp_dim', default=512, type=int) 38 | parser.add_argument('--dim_head', default=64, type=int) 39 | 40 | parser.add_argument('--input_pe', action='store_true', help='use pe at input') 41 | parser.add_argument("--init", type=str, default="none") 42 | parser.add_argument("--pe_choice", type=str, default="sin") 43 | parser.add_argument('--use_value', action='store_true', help='use value') 44 | parser.add_argument("--spatial_pe", action='store_true', help='use pe for spatial mixing') 45 | parser.add_argument("--spatial_x", action='store_true', help='use x for spatial mixing') 46 | parser.add_argument('--alpha', default=0.5, type=float, help='balance pe and x') 47 | parser.add_argument("--trainable", action='store_true', help='let spatial mixing to be trainable') 48 | 49 | 50 | args = parser.parse_args() 51 | use_amp = not args.noamp 52 | print(args) 53 | 54 | 55 | usewandb = not args.nowandb 56 | if usewandb: 57 | import wandb 58 | watermark = "{}_h{}".format(args.init,args.heads) 59 | wandb.init(project=f'vit-{args.dataset}', name=watermark) 60 | wandb.config.update(args) 61 | 62 | 63 | print(f'==> Preparing {args.dataset} data..') 64 | if args.dataset[:5] == 'cifar': 65 | image_size = 32 66 | dataset_mean = (0.4914, 0.4822, 0.4465) 67 | dataset_std = (0.2471, 0.2435, 0.2616) 68 | elif args.dataset == 'svhn': 69 | image_size = 32 70 | dataset_mean = (0.4376821, 0.4437697, 0.47280442) 71 | dataset_std = (0.19803012, 0.20101562, 0.19703614) 72 | elif args.dataset == 'tiny_imagenet': 73 | image_size = 64 74 | args.psize = 4 75 | print(args) 76 | dataset_mean = (0.485, 0.456, 0.406) 77 | dataset_std = (0.229, 0.224, 0.225) 78 | else: 79 | print('no available dataset') 80 | if args.data_aug: 81 | train_transform = transforms.Compose([ 82 | transforms.RandAugment(2, 14), 83 | transforms.RandomHorizontalFlip(), 84 | transforms.RandomResizedCrop(image_size), 85 | transforms.ToTensor(), 86 | transforms.Normalize(dataset_mean, dataset_std) 87 | ]) 88 | else: 89 | train_transform = transforms.Compose([ 90 | transforms.Resize(image_size), 91 | transforms.ToTensor(), 92 | transforms.Normalize(dataset_mean, dataset_std) 93 | ]) 94 | 95 | test_transform = transforms.Compose([ 96 | transforms.Resize(image_size), 97 | transforms.ToTensor(), 98 | transforms.Normalize(dataset_mean, dataset_std) 99 | ]) 100 | if args.dataset == 'cifar10': 101 | n_class = 10 102 | trainset = torchvision.datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=train_transform) 103 | testset = torchvision.datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=test_transform) 104 | elif args.dataset == 'cifar100': 105 | n_class = 100 106 | trainset = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=train_transform) 107 | testset = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=test_transform) 108 | elif args.dataset == 'svhn': 109 | n_class = 10 110 | trainset = torchvision.datasets.SVHN(root=args.data_path, split='train', download=True, transform=train_transform) 111 | testset = torchvision.datasets.SVHN(root=args.data_path, split='test', download=True, transform=test_transform) 112 | elif args.dataset == 'tiny_imagenet': 113 | 
    trainset = torchvision.datasets.ImageFolder(root=args.data_path+'/tiny-imagenet-200/train', transform=train_transform)
114 |     testset = torchvision.datasets.ImageFolder(root=args.data_path+'/tiny-imagenet-200/val/images', transform=test_transform)
115 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
116 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
if args.dataset == 'tiny_imagenet':
    n_class = 200  # fix: tiny-imagenet-200 has 200 classes; n_class was never set in the branch above
                   # (run create_val_img_folder(args) from util.py once so val images are sorted into class folders)
117 | 
118 | 
119 | 
120 | print('==> Building model..')
121 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
122 | model = SimpleViT(image_size=image_size, patch_size=args.psize, num_classes=n_class, dim=args.dim, depth=args.depth, heads=args.heads, mlp_dim=args.mlp_dim, dim_head=args.dim_head,
123 |                   input_pe=args.input_pe, pe_choice=args.pe_choice, use_value=args.use_value, spatial_pe=args.spatial_pe, spatial_x=args.spatial_x, init=args.init, alpha=args.alpha, trainable=args.trainable)
124 | if 'cuda' in device:
125 |     print(device)
126 |     print("using data parallel")
127 |     model = torch.nn.DataParallel(model).cuda()
128 |     cudnn.benchmark = True
129 | criterion = nn.CrossEntropyLoss()
130 | 
131 | if args.opt == "adam":
132 |     optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
133 | elif args.opt == "adamw":
134 |     optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
135 | 
136 | if args.scheduler == "cos":
137 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
138 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
139 | 
140 | num_param = count_parameters(model)
141 | for epoch in range(args.epochs):
142 |     start = time.time()
143 |     train_loss, train_acc, n = 0, 0, 0
144 |     for i, (X, y) in enumerate(trainloader):
145 |         model.train()
146 |         X, y = X.cuda(), y.cuda()
147 |         with torch.cuda.amp.autocast(enabled=use_amp):
148 |             output = model(X)
149 |             loss = criterion(output, y)
150 |         scaler.scale(loss).backward()
151 |         scaler.step(optimizer)
152 |         scaler.update()
153 |         optimizer.zero_grad()
154 | 
155 |         train_loss += loss.item() * y.size(0)
156 |         train_acc += (output.max(1)[1] == y).sum().item()
157 |         n += y.size(0)
158 |     train_acc = train_acc/n
159 |     train_loss = train_loss/n
160 | 
161 | 
162 |     model.eval()
163 |     test_acc, m = 0, 0
164 |     with torch.no_grad():
165 |         for i, (X, y) in enumerate(testloader):
166 |             X, y = X.cuda(), y.cuda()
167 |             with torch.cuda.amp.autocast():
168 |                 output = model(X)
169 |             test_acc += (output.max(1)[1] == y).sum().item()
170 |             m += y.size(0)
171 |     test_acc = test_acc/m
172 |     scheduler.step()
173 | 
174 |     if usewandb:
175 |         wandb.log({'epoch': epoch, 'train_loss': train_loss, 'train_acc': train_acc, "val_acc": test_acc, "lr": optimizer.param_groups[0]["lr"],
176 |                    "epoch_time": time.time()-start, 'num_param':num_param})
177 |     else:
178 |         print(f'epoch: {epoch}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, val_acc: {test_acc:.4f}, lr: {optimizer.param_groups[0]["lr"]:.6f}, epoch_time: {time.time()-start:.1f}, num_param:{num_param}')
--------------------------------------------------------------------------------
/convmixer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | 
5 | 
6 | # initialization of convolution kernel, C*1*K*K
7 | def SpatialConv2d_init(C, kernel_size, init='random'):
8 |     if (init == 'random')|(init == 'softmax'):
9 |
weight = 1/kernel_size*(2*torch.rand((C,1,kernel_size,kernel_size))-1) 10 | elif init == 'impulse': 11 | k = torch.randint(0,kernel_size*kernel_size,(C,1)) 12 | weight = torch.zeros((C,1,kernel_size*kernel_size)) 13 | for i in range(C): 14 | for j in range(1): 15 | weight[i,j,k[i,j]] = 1 16 | weight = np.sqrt(1/3)*weight.reshape(C,1,kernel_size,kernel_size) 17 | elif init[:3] == 'box': 18 | weight = torch.zeros((C,1,kernel_size*kernel_size)) 19 | for i in range(C): 20 | for j in range(1): 21 | k = np.random.choice(kernel_size*kernel_size,int(init[3:]),replace=False) 22 | weight[i,j,k] = 1 23 | weight = np.sqrt(1/int(init[3:])/3)*weight.reshape(C,1,kernel_size,kernel_size) 24 | elif init[:3] == 'gau': 25 | k = torch.randint(0,kernel_size,(C,1,2)) 26 | weight = torch.zeros((C,1,kernel_size,kernel_size)) 27 | for i in range(C): 28 | for j in range(1): 29 | for p in range(kernel_size): 30 | for q in range(kernel_size): 31 | weight[i,j,p,q] = (-0.5/float(init[3:])*((p-k[i,j,0])**2+(q-k[i,j,1])**2)).exp() 32 | weight = weight/((weight.flatten(1,3)**2).sum(1).mean()*3).sqrt() 33 | else: 34 | return -1 35 | return weight 36 | 37 | 38 | # initialization of convolusion kernel in linear format, C*img_size*img_size, Out of Memory!!! 39 | def SpatialConv2d_Linear_init(C, H, W, init='perm'): 40 | weight = torch.zeros((C,H*W,H*W)) 41 | if init == 'perm': 42 | for i in range(C): 43 | k = torch.randint(0,H*W,(H*W,)) 44 | for j in range(H*W): 45 | weight[i][j,k[j]] = 1 46 | weight = np.sqrt(1/3)*weight 47 | elif init == 'fullperm': 48 | for i in range(C): 49 | k = torch.randperm(H*W) 50 | for j in range(H*W): 51 | weight[i][j,k[j]] = 1 52 | weight = np.sqrt(1/3)*weight 53 | elif init[:6] == 'impulse': 54 | ff = int(init[7:]) 55 | weight = torch.zeros((C,H*W,H*W)) 56 | k = torch.randint(0,ff**2,(C,)) 57 | for i in range(C): 58 | m = (k[i]//ff)-(ff//2) 59 | n = (k[i]%ff)-(ff//2) 60 | tmp_weight = torch.zeros((W,W)) 61 | for j in range(0-min(0,n),W-max(0,n)): 62 | tmp_weight[j,j+n] = 1 63 | for j in range(0-min(0,m),H-max(0,m)): 64 | weight[i,j*W:(j+1)*W,(j+m)*W:(j+m+1)*W] = tmp_weight 65 | weight = np.sqrt(1/3)*weight 66 | else: 67 | return -1 68 | return weight 69 | 70 | 71 | 72 | # my spatial conv fuction, group=#channels, heads controls the number of different conv filters 73 | class SpatialConv2d(nn.Module): 74 | def __init__(self, C, kernel_size, bias=True, init='random', heads = -1, trainable= True, input_weight=None): 75 | super(SpatialConv2d, self).__init__() 76 | self.C = C 77 | self.kernel_size = kernel_size 78 | self.init = init 79 | 80 | # different initialisation 81 | weight = SpatialConv2d_init(C,kernel_size,init=init) 82 | 83 | # how many heads or different filters we want to use 84 | if (heads<1)|(heads>C) : 85 | heads = C 86 | self.choice_idx = np.random.choice(heads,C,replace=(headsC) : 124 | heads = C 125 | self.choice_idx = np.random.choice(heads,C,replace=(headsdim) : 175 | heads = dim 176 | else: 177 | self.input_weight = None 178 | 179 | # choose spatial conv format 180 | if linear_format: 181 | H=int(image_size/patch_size) 182 | W=int(image_size/patch_size) 183 | if input_weight: 184 | self.input_weight = SpatialConv2d_Linear_init(dim, H, W, init=init) 185 | self.input_weight = nn.Parameter(self.input_weight[:heads],requires_grad=spatial) 186 | for _ in range(depth): 187 | self.mixer.append(nn.ModuleList([ 188 | Residual(nn.Sequential( 189 | SpatialConv2d_LinearFormat(dim, H, W, bias=spatial_bias, init=init, heads= heads, trainable=spatial, input_weight=self.input_weight), 190 | 
nn.GELU(), nn.BatchNorm2d(dim))), 191 | nn.Sequential(nn.Conv2d(dim, dim, kernel_size=1), nn.GELU(), nn.BatchNorm2d(dim)) 192 | ])) 193 | else: 194 | if input_weight: 195 | self.input_weight = SpatialConv2d_init(dim, kernel_size, init=init) 196 | self.input_weight = nn.Parameter(self.input_weight[:heads],requires_grad=spatial) 197 | for _ in range(depth): 198 | self.mixer.append(nn.ModuleList([ 199 | Residual(nn.Sequential( 200 | SpatialConv2d(dim, kernel_size, bias=spatial_bias, init=init, heads= heads, trainable=spatial, input_weight=self.input_weight), 201 | nn.BatchNorm2d(dim))), 202 | # missing GeLU here !!!! 203 | nn.Sequential(nn.Conv2d(dim, dim, kernel_size=1),nn.GELU(),nn.BatchNorm2d(dim)) 204 | ])) 205 | 206 | self.output = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), 207 | nn.Flatten(), 208 | nn.Linear(dim, n_classes)) 209 | 210 | def forward(self, x): 211 | if self.return_embeding: xs = [x] 212 | x = self.input_embedding(x) 213 | if self.return_embeding: xs.append(x) 214 | for spatial,channel in self.mixer: 215 | x = spatial(x) 216 | x = channel(x) 217 | if self.return_embeding: xs.append(x) 218 | x = self.output(x) 219 | if self.return_embeding: 220 | return x, xs 221 | else: 222 | return x -------------------------------------------------------------------------------- /vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | 5 | import numpy as np 6 | from einops import rearrange 7 | from einops.layers.torch import Rearrange 8 | 9 | # helpers 10 | 11 | def pair(t): 12 | return t if isinstance(t, tuple) else (t, t) 13 | 14 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32): 15 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij") 16 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb" 17 | omega = torch.arange(dim // 4) / (dim // 4 - 1) 18 | omega = 1.0 / (temperature ** omega) 19 | 20 | y = y.flatten()[:, None] * omega[None, :] 21 | x = x.flatten()[:, None] * omega[None, :] 22 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1) 23 | return pe.type(dtype) 24 | 25 | 26 | # my impulse initilization function 27 | def impulse_init(heads,img_size,att_rank,ff,scale=1.0,spatial_pe=None,norm=1): 28 | weight = torch.zeros((heads,img_size**2,img_size**2)) 29 | k = torch.randint(0,ff**2,(heads,)) 30 | for i in range(heads): 31 | m = (k[i]//ff)-(ff//2) 32 | n = (k[i]%ff)-(ff//2) 33 | tmp_weight = torch.zeros((img_size,img_size)) 34 | for j in range(0-min(0,n),img_size-max(0,n)): 35 | tmp_weight[j,j+n] = 1 36 | for j in range(0-min(0,m),img_size-max(0,m)): 37 | weight[i,j*img_size:(j+1)*img_size,(j+m)*img_size:(j+m+1)*img_size] = tmp_weight 38 | # weight = np.sqrt(1/3)*weight 39 | class PermuteM(nn.Module): 40 | def __init__(self, heads, img_size, att_rank,scale=1.0,spatial_pe=None): 41 | super().__init__() 42 | self.scale = scale 43 | if spatial_pe is None: 44 | self.spatial_pe = False 45 | weights_Q = np.sqrt(1/att_rank/heads)*(2*torch.rand(heads,img_size,att_rank)-1) 46 | weights_K = np.sqrt(1/att_rank/heads)*(2*torch.rand(heads,att_rank,img_size)-1) 47 | else: 48 | self.spatial_pe = True 49 | self.pe = spatial_pe.cuda() 50 | weights_Q = np.sqrt(1/att_rank/heads)*(2*torch.rand(heads,spatial_pe.shape[1],att_rank)-1) 51 | weights_K = np.sqrt(1/att_rank/heads)*(2*torch.rand(heads, att_rank, spatial_pe.shape[1])-1) 52 | 53 | self.weights_K = nn.Parameter(weights_K) 54 | self.weights_Q = 
nn.Parameter(weights_Q) 55 | def forward(self): 56 | if self.spatial_pe: 57 | M = self.pe@self.weights_Q@self.weights_K@(self.pe.T) 58 | else: 59 | M = torch.bmm(self.weights_Q,self.weights_K) 60 | return torch.softmax(M*self.scale,-1) 61 | 62 | net = PermuteM(heads,img_size**2,att_rank,scale,spatial_pe) 63 | net.cuda() 64 | 65 | nq = net.weights_Q.detach().cpu().norm(dim=(1)).mean() 66 | weight = weight.cuda() 67 | num_epoch = 10000 68 | criterion = nn.MSELoss() 69 | optimizer = optim.Adam(net.parameters(), lr=0.0001)#,weight_decay=1e-6) 70 | for i in range(num_epoch): 71 | if i%norm==0: 72 | with torch.no_grad(): 73 | net.weights_Q.div_(net.weights_Q.detach().norm(dim=(1),keepdim=True)/nq) 74 | net.weights_K.div_(net.weights_K.detach().norm(dim=(1),keepdim=True)/nq) 75 | optimizer.zero_grad() 76 | outputs = net() 77 | loss = criterion(outputs, weight) 78 | loss.backward() 79 | optimizer.step() 80 | print(loss.data) 81 | 82 | return net.weights_Q.detach().cpu(),net.weights_K.detach().cpu() 83 | 84 | # classes 85 | 86 | class FeedForward(nn.Module): 87 | def __init__(self, dim, hidden_dim): 88 | super().__init__() 89 | self.net = nn.Sequential( 90 | nn.LayerNorm(dim), 91 | nn.Linear(dim, hidden_dim), 92 | nn.GELU(), 93 | nn.Linear(hidden_dim, dim), 94 | ) 95 | def forward(self, x): 96 | return self.net(x) 97 | 98 | class Attention(nn.Module): 99 | def __init__(self, dim, heads = 8, dim_head = 64, use_value = True, spatial_pe = None, 100 | spatial_x = True, init = 'none', alpha=1.0, trainable=True, out_layer=True): 101 | super().__init__() 102 | inner_dim = dim_head * heads 103 | self.scale = dim_head ** -0.5 104 | self.heads = heads 105 | self.norm = nn.LayerNorm(dim) 106 | self.attend = nn.Softmax(dim = -1) 107 | 108 | self.alpha = alpha 109 | 110 | # input to q&k 111 | self.spatial_x = spatial_x 112 | self.spatial_pe = False 113 | if spatial_pe is not None: 114 | self.spatial_pe = True 115 | self.pos_embedding = spatial_pe 116 | 117 | # format & initilization of q&k 118 | self.init = init 119 | if init == 'none': 120 | self.to_qk = nn.Linear(dim, inner_dim*2, bias = False) 121 | else: 122 | if init[:7] == 'impulse': 123 | a, b, c, d, e = init[7:].split('_') 124 | img_size = int(a) 125 | att_rank = int(b) 126 | ff = int(c) 127 | self.scale = float(d) 128 | norm = int(e) 129 | Q, K = impulse_init(heads,img_size,att_rank,ff,self.scale,spatial_pe,norm) 130 | elif init[:6] == 'random': 131 | a, b = init[6:].split('_') 132 | img_size = int(a) 133 | att_rank = int(b) 134 | Q = np.sqrt(1/img_size)*(2*torch.rand(heads,img_size,att_rank)-1) 135 | K = np.sqrt(1/img_size)*(2*torch.rand(heads,att_rank,img_size)-1) 136 | elif init[:7] == 'mimetic': 137 | a, b = init[7:].split('_') 138 | img_size = int(a) 139 | att_rank = int(b) 140 | W = 0.7*np.sqrt(1/img_size)*(2*torch.rand(heads,img_size,img_size)-1)+0.7*torch.eye(img_size).unsqueeze(0).repeat(heads,1,1) 141 | U,s,V = torch.linalg.svd(W) 142 | s_2 = torch.sqrt(s) 143 | Q = torch.matmul(U[:,:,:att_rank], torch.diag_embed(s_2)[:,:att_rank,:att_rank]) 144 | K = torch.matmul(torch.diag_embed(s_2)[:,:att_rank,:att_rank], V[:,:att_rank,:]) 145 | if self.spatial_pe|self.spatial_x: 146 | print('use linear format') 147 | self.to_qk = nn.Linear(dim, inner_dim*2, bias = False) 148 | self.to_qk.weight.data[:512,:] = rearrange(Q, 'h n d -> n (h d)').T 149 | self.to_qk.weight.data[512:,:] = rearrange(K, 'h d n -> n (h d)').T 150 | else: 151 | print('use Q K format') 152 | self.Q = nn.Parameter(Q,requires_grad=trainable) 153 | self.K = 
nn.Parameter(K,requires_grad=trainable) 154 | 155 | # use v or just use x 156 | self.use_value = use_value 157 | if use_value: 158 | self.to_v = nn.Linear(dim, inner_dim, bias = False) 159 | 160 | # use output layer or not 161 | self.out_layer = out_layer 162 | if self.out_layer: 163 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 164 | 165 | 166 | 167 | def forward(self, x): 168 | x = self.norm(x) 169 | 170 | # use v or just use x 171 | if self.use_value: 172 | v = self.to_v(x) 173 | else: 174 | v = x 175 | 176 | # q&k format 177 | if self.spatial_pe|self.spatial_x: 178 | # input to q&v 179 | device = x.device 180 | if self.spatial_pe&self.spatial_x: 181 | x = self.alpha*x + (1-self.alpha)*self.pos_embedding.to(device, dtype=x.dtype) 182 | elif self.spatial_pe: 183 | x = 0*x + self.pos_embedding.to(device, dtype=x.dtype) 184 | qk = self.to_qk(x).chunk(2, dim = -1) 185 | q, k = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qk) 186 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 187 | else: 188 | dots = torch.matmul(self.Q, self.K) * self.scale 189 | attn = self.attend(dots) 190 | 191 | out = torch.matmul(attn, rearrange(v, 'b n (h d) -> b h n d', h = self.heads)) 192 | out = rearrange(out, 'b h n d -> b n (h d)') 193 | if self.out_layer: 194 | return self.to_out(out) 195 | else: 196 | return out 197 | 198 | 199 | 200 | class Transformer(nn.Module): 201 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_value=True, spatial_pe=None, spatial_x = True, init = 'none', alpha=1.0, trainable=False): 202 | super().__init__() 203 | self.norm = nn.LayerNorm(dim) 204 | self.layers = nn.ModuleList([]) 205 | for _ in range(depth): 206 | self.layers.append(nn.ModuleList([ 207 | Attention(dim, heads, dim_head, use_value, spatial_pe, spatial_x, init, alpha, trainable), 208 | FeedForward(dim, mlp_dim) 209 | ])) 210 | def forward(self, x): 211 | for attn, ff in self.layers: 212 | x = attn(x) + x 213 | x = ff(x) + x 214 | return self.norm(x) 215 | 216 | class SimpleViT(nn.Module): 217 | def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, 218 | input_pe = True, pe_choice='sin', use_value = True, spatial_pe = False, spatial_x = True, init = 'none', alpha=0.5, trainable=False): 219 | super().__init__() 220 | 221 | self.input_pe = input_pe 222 | self.use_value = use_value 223 | self.alpha = alpha 224 | if input_pe: alpha_inside = 1.0 225 | else: alpha_inside = alpha 226 | 227 | image_height, image_width = pair(image_size) 228 | patch_height, patch_width = pair(patch_size) 229 | 230 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
231 | 232 | patch_dim = channels * patch_height * patch_width 233 | 234 | self.to_patch_embedding = nn.Sequential( 235 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width), 236 | nn.LayerNorm(patch_dim), 237 | nn.Linear(patch_dim, dim), 238 | nn.LayerNorm(dim), 239 | ) 240 | 241 | if pe_choice == 'sin': 242 | self.pos_embedding = posemb_sincos_2d( 243 | h = image_height // patch_height, 244 | w = image_width // patch_width, 245 | dim = dim, 246 | ) 247 | elif pe_choice == 'identity': 248 | # self.pos_embedding = torch.eye(64).repeat(1,8).type(torch.float32) 249 | s = (image_height // patch_height)*(image_width // patch_width) 250 | self.pos_embedding = torch.cat([torch.eye(s),torch.zeros(s,dim-s)],dim=-1).type(torch.float32) 251 | if spatial_pe: 252 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_value, self.pos_embedding, spatial_x, init, alpha_inside, trainable) 253 | else: # change dim_heads here 254 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_value, None, spatial_x, init, alpha_inside, trainable) 255 | 256 | self.pool = "mean" 257 | self.to_latent = nn.Identity() 258 | 259 | self.linear_head = nn.Linear(dim, num_classes) 260 | 261 | def forward(self, img): 262 | device = img.device 263 | 264 | x = self.to_patch_embedding(img) 265 | if self.input_pe: 266 | x = self.alpha*x + (1-self.alpha)*self.pos_embedding.to(device, dtype=x.dtype) 267 | 268 | x = self.transformer(x) 269 | x = x.mean(dim = 1) 270 | 271 | x = self.to_latent(x) 272 | return self.linear_head(x) --------------------------------------------------------------------------------
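A quick way to see what the convolutional ("impulse") initialization is aiming for: the target attention maps that impulse_init in vit.py fits Q and K against can be built directly. A minimal CPU-only sketch reusing the same construction (the helper name below is illustrative, not part of the repo):

import torch

def impulse_attention_target(heads, grid, ff):
    # One 0/1 target map per head over the flattened grid*grid patch tokens:
    # each map shifts the token grid by a random offset (m, n) drawn from a
    # ff x ff window, i.e. attention acting as a single-tap convolution.
    weight = torch.zeros((heads, grid * grid, grid * grid))
    k = torch.randint(0, ff * ff, (heads,))
    for i in range(heads):
        m = int(k[i]) // ff - ff // 2   # vertical shift
        n = int(k[i]) % ff - ff // 2    # horizontal shift
        tmp = torch.zeros((grid, grid))
        for j in range(-min(0, n), grid - max(0, n)):
            tmp[j, j + n] = 1
        for j in range(-min(0, m), grid - max(0, m)):
            weight[i, j * grid:(j + 1) * grid, (j + m) * grid:(j + m + 1) * grid] = tmp
    return weight

targets = impulse_attention_target(heads=8, grid=16, ff=5)   # matches init='impulse16_64_5_0.1_100'
print(targets.shape)                 # torch.Size([8, 256, 256])
print(targets.sum(-1).max().item())  # 1.0 -- every row attends to at most one token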