├── imgs
│   └── ConvMixer-ViT.png
├── requirements.txt
├── convmixer.sh
├── vit_pex.sh
├── LICENSE
├── util.py
├── README.md
├── mytrain_convmixer.py
├── mytrain_vit.py
├── convmixer.py
└── vit.py
/imgs/ConvMixer-ViT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/osiriszjq/impulse_init/HEAD/imgs/ConvMixer-ViT.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | matplotlib
3 | torch==1.13.1
4 | torchvision==0.14.1
5 | einops
6 | wandb
--------------------------------------------------------------------------------
/convmixer.sh:
--------------------------------------------------------------------------------
1 | data_path='./data'
2 | for lr in '1e-3'; do
3 | for dataset in 'cifar10' 'cifar100'; do
4 | for heads in '512'; do
5 | for init in 'random' 'softmax' 'box1' 'box25'; do
6 |
7 | # for i in 1 2 3 4 5; do
8 | python mytrain_convmixer.py --dataset ${dataset} --data_path ${data_path} --init ${init} --heads ${heads} --lr ${lr}
9 | python mytrain_convmixer.py --dataset ${dataset} --data_path ${data_path} --init ${init} --heads ${heads} --lr ${lr} --fix_spatial
10 | # done
11 |
12 | done
13 | done
14 | done
15 | done
--------------------------------------------------------------------------------
/vit_pex.sh:
--------------------------------------------------------------------------------
1 | data_path='./data'
2 | for dataset in 'cifar10' 'cifar100' 'svhn' 'tiny_imagenet'; do
3 | for lr in '1e-4'; do
4 | for init in 'impulse16_64_5_0.1_100' 'mimetic512_64' 'random512_64'; do # 'impulse16_64_5_0.1_100' 'mimetic512_64' 'random512_64'
5 | # for i in 1 2 3 4 5; do
6 | for alpha in '0.0' '0.1' '0.2' '0.3' '0.4' '0.5'; do
7 | python mytrain_vit.py --dataset ${dataset} --data_path ${data_path} --lr ${lr} --spatial_pe --spatial_x --init ${init} --use_value --trainable --alpha ${alpha} --data_aug
8 | python mytrain_vit.py --dataset ${dataset} --data_path ${data_path} --lr ${lr} --spatial_pe --spatial_x --init ${init} --use_value --trainable --alpha ${alpha}
9 | done
10 | # done
11 |
12 | done
13 | done
14 | done
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Jianqiao Zheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | # stable ("effective") rank measures: sr1 = sum(s)/s_max, sr2 = sum(s^2)/s_max^2
5 | def srank_l12(X):
6 | s = torch.linalg.svdvals(X)
7 | sr2 = (s*s).sum()/s[0]/s[0]
8 | sr1 = s.sum()/s[0]
9 | return sr1,sr2
10 |
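# Usage sketch (illustrative, not part of the original file): both measures lie in
# [1, min(X.shape)] and equal 1 for a rank-1 matrix, e.g.
#   sr1, sr2 = srank_l12(torch.randn(64, 64))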
11 |
12 | # my counting parameters
13 | def count_parameters(model):
14 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
15 |
16 |
17 | # for tiny imagenet 200
18 | def create_val_img_folder(args):
19 | '''
20 | This method is responsible for separating validation images into separate sub folders
21 | '''
22 | dataset_dir = os.path.join(args.data_path, 'tiny-imagenet-200')
23 | val_dir = os.path.join(dataset_dir, 'val')
24 | img_dir = os.path.join(val_dir, 'images')
25 |
26 | fp = open(os.path.join(val_dir, 'val_annotations.txt'), 'r')
27 | data = fp.readlines()
28 | val_img_dict = {}
29 | for line in data:
30 | words = line.split('\t')
31 | val_img_dict[words[0]] = words[1]
32 | fp.close()
33 |
34 | # Create folder if not present and move images into proper folders
35 | for img, folder in val_img_dict.items():
36 | newpath = (os.path.join(img_dir, folder))
37 | if not os.path.exists(newpath):
38 | os.makedirs(newpath)
39 | if os.path.exists(os.path.join(img_dir, img)):
40 | os.rename(os.path.join(img_dir, img), os.path.join(newpath, img))
41 |
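# After this runs once, the val split is laid out as val/images/<wnid>/<filename>,
# which is what the torchvision ImageFolder call in mytrain_vit.py expects.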
42 |
43 | def get_class_name(args):
44 | class_to_name = dict()
45 | fp = open(os.path.join(args.data_path, 'tiny-imagenet-200', 'words.txt'), 'r')
46 | data = fp.readlines()
47 | for line in data:
48 | words = line.strip('\n').split('\t')
49 | class_to_name[words[0]] = words[1].split(',')[0]
50 | fp.close()
51 | return class_to_name
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Convolutional Initialization for Data-Efficient Vision Transformers
2 | ### [Project Page](https://osiriszjq.github.io/impulse_init) | [Paper](https://arxiv.org/pdf/2401.12511.pdf)
3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
4 |
5 |
6 | [Jianqiao Zheng](https://github.com/osiriszjq/),
7 | [Xueqian Li](https://lilac-lee.github.io/),
8 | [Simon Lucey](https://www.adelaide.edu.au/directory/simon.lucey)<br>
9 | The University of Adelaide
10 |
11 | ---
12 | 🚀 **New! Check out our follow-up NeurIPS'25 work [Structured Initialization for Vision Transformers](https://github.com/osiriszjq/structured_initialization), released in December 2025.**
13 |
14 | ---
15 |
16 | This is the official implementation of the paper "Convolutional Initialization for Data-Efficient Vision Transformers", including modified versions of [ConvMixer](https://arxiv.org/abs/2201.09792) and [Simple ViT](https://arxiv.org/abs/2205.01580) trained on CIFAR-10, CIFAR-100, SVHN, and [Tiny ImageNet](http://vision.stanford.edu/teaching/cs231n/reports/2015/pdfs/yle_project.pdf). The code is based on [vision-transformers-cifar10](https://github.com/kentaroy47/vision-transformers-cifar10/tree/main).
17 |
18 | #### Illustration of different methods to extend 1D encoding
19 | 
20 |
21 |
22 | ## Google Colab
23 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://github.com/osiriszjq/impulse_init/blob/main/Impulse_Initialization.ipynb)
24 | If you want to try out our new initialization for ViT, check this [Colab](https://github.com/osiriszjq/impulse_init/blob/main/Impulse_Initialization.ipynb) for a quick tour.
25 |
26 |
27 | ## Usage
28 | First edit `convmixer.sh` or `vit_pex.sh` to set the data path and select the experiments you want to run, then simply run
29 | ```bash
30 | bash convmixer.sh
31 | ```
32 | or
33 | ```bash
34 | bash vit_pex.sh
35 | ```
36 |
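You can also launch a single configuration directly instead of a sweep. For example, a hypothetical single run of the ViT script with the impulse initialization (all flags are defined in `mytrain_vit.py`; add `--nowandb` to skip wandb logging):
```bash
python mytrain_vit.py --dataset cifar10 --data_path ./data --lr 1e-4 \
    --spatial_pe --spatial_x --init impulse16_64_5_0.1_100 \
    --use_value --trainable --alpha 0.1 --data_aug --nowandb
```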
37 |
38 | ## Citation
39 | ```
40 | @article{zheng2024convolutional,
41 | title={Convolutional Initialization for Data-Efficient Vision Transformers},
42 | author={Zheng, Jianqiao and Li, Xueqian and Lucey, Simon},
43 | journal={arXiv preprint arXiv:2401.12511},
44 | year={2024}
45 | }
46 | ```
47 |
--------------------------------------------------------------------------------
/mytrain_convmixer.py:
--------------------------------------------------------------------------------
1 | # https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/train_cifar10.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 |
11 | import time
12 | import argparse
13 | from util import *
14 | from convmixer import *
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 |
19 | parser.add_argument("--dataset", type=str, default="cifar10")
20 | parser.add_argument("--data_path", type=str, default="./data")
21 | parser.add_argument('--nowandb', action='store_true', help='disable wandb')
22 |
23 | parser.add_argument('--opt', default="adam")
24 | parser.add_argument('--scheduler', default="cos")
25 | parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
26 | parser.add_argument('--wd', default=0.0, type=float)
27 | parser.add_argument('--epochs', default=200, type=int)
28 | parser.add_argument('--batch-size', default=512, type=int)
29 | parser.add_argument('--workers', default=16, type=int)
30 | parser.add_argument('--data_aug', action='store_true')
31 | parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. for older pytorch versions')
32 |
33 | parser.add_argument('--dim', default=512, type=int)
34 | parser.add_argument('--heads', default=0, type=int, help='number of different filters in one layer')
35 | parser.add_argument('--depth', default=6, type=int)
36 | parser.add_argument('--psize', default=2, type=int)
37 | parser.add_argument('--conv-ks', default=5, type=int)
38 |
39 | parser.add_argument('--fix_spatial', action='store_true', help='freeze spatial mixing')
40 | parser.add_argument("--init", type=str, default="random")
41 | parser.add_argument('--input_weight', action='store_true', help='share weights in different layers')
42 | parser.add_argument("--linear_format", action='store_true', help='use linear format conv filters')
43 | parser.add_argument("--no_spatial_bias", action='store_true', help='disable bias for spatial conv')
44 |
45 |
46 | args = parser.parse_args()
47 | args.spatial = not args.fix_spatial
48 | args.spatial_bias = not args.no_spatial_bias
49 | use_amp = not args.noamp
50 | print(args)
51 |
52 |
53 | usewandb = not args.nowandb
54 | if usewandb:
55 | import wandb
56 | watermark = "{}_h{}".format(args.init,args.heads)
57 | wandb.init(project=f'convmixer-{args.dataset}',name=watermark)
58 | wandb.config.update(args)
59 |
60 |
61 |
62 | print(f'==> Preparing {args.dataset} data..')
63 | dataset_mean = (0.4914, 0.4822, 0.4465)
64 | dataset_std = (0.2471, 0.2435, 0.2616)
65 | if args.data_aug:
66 | train_transform = transforms.Compose([
67 | transforms.RandAugment(2, 14),
68 | transforms.RandomResizedCrop(32, scale=(1.0,1.0), ratio=(1.0,1.0)),
69 | transforms.RandomHorizontalFlip(),
70 | transforms.ToTensor(),
71 | transforms.Normalize(dataset_mean, dataset_std)
72 | ])
73 | else:
74 | train_transform = transforms.Compose([
75 | transforms.ToTensor(),
76 | transforms.Normalize(dataset_mean, dataset_std)
77 | ])
78 |
79 | test_transform = transforms.Compose([
80 | transforms.ToTensor(),
81 | transforms.Normalize(dataset_mean, dataset_std)
82 | ])
83 | if args.dataset == 'cifar10':
84 | n_class = 10
85 | image_size = 32
86 | trainset = torchvision.datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=train_transform)
87 | testset = torchvision.datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=test_transform)
88 | elif args.dataset == 'cifar100':
89 | n_class = 100
90 | image_size = 32
91 | trainset = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=train_transform)
92 | testset = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=test_transform)
93 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
94 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
95 |
96 |
97 |
98 | print('==> Building model..')
99 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
100 | model = ConvMixer(args.dim, args.depth, patch_size=args.psize, kernel_size=args.conv_ks, n_classes=n_class, image_size=image_size, return_embedding=False,
101 | init=args.init, heads=args.heads, spatial=args.spatial, spatial_bias=args.spatial_bias, input_weight=args.input_weight,linear_format=args.linear_format)
102 | if 'cuda' in device:
103 | print(device)
104 | print("using data parallel")
105 | model = torch.nn.DataParallel(model).cuda()
106 | cudnn.benchmark = True
107 | criterion = nn.CrossEntropyLoss()
108 |
109 | if args.opt == "adam":
110 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
111 | elif args.opt == "adamw":
112 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
113 |
114 | if args.scheduler == "cos":
115 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
116 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
117 |
118 | num_param = count_parameters(model)
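# Standard AMP recipe below: the forward pass runs under autocast, the loss is scaled
# before backward so fp16 gradients don't underflow, and GradScaler unscales at step time.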
119 | for epoch in range(args.epochs):
120 | start = time.time()
121 | train_loss, train_acc, n = 0, 0, 0
122 | n_batch = 0
123 | for i, (X, y) in enumerate(trainloader):
124 | model.train()
125 | X, y = X.cuda(), y.cuda()
126 |
127 | with torch.cuda.amp.autocast(enabled=use_amp):
128 | output = model(X)
129 | loss = criterion(output, y)
130 | scaler.scale(loss).backward()
131 | scaler.step(optimizer)
132 | scaler.update()
133 | optimizer.zero_grad()
134 |
135 | train_loss += loss.item() * y.size(0)
136 | train_acc += (output.max(1)[1] == y).sum().item()
137 | n += y.size(0)
138 | train_acc = train_acc/n
139 | train_loss = train_loss/n
140 |
141 |
142 | model.eval()
143 | test_acc, m = 0, 0
144 | with torch.no_grad():
145 | for i, (X, y) in enumerate(testloader):
146 | X, y = X.cuda(), y.cuda()
147 | with torch.cuda.amp.autocast(enabled=use_amp):
148 | output = model(X)
149 | test_acc += (output.max(1)[1] == y).sum().item()
150 | m += y.size(0)
151 | test_acc = test_acc/m
152 | scheduler.step()
153 |
154 | if usewandb:
155 | wandb.log({'epoch': epoch, 'train_loss': train_loss, 'train_acc': train_acc, "val_acc": test_acc, "lr": optimizer.param_groups[0]["lr"],
156 | "epoch_time": time.time()-start, 'num_param':num_param})
157 | else:
158 | print(f'epoch: {epoch}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, val_acc: {test_acc:.4f}, lr: {optimizer.param_groups[0]["lr"]:.6f}, epoch_time: {time.time()-start:.1f}, num_param:{num_param}')
--------------------------------------------------------------------------------
/mytrain_vit.py:
--------------------------------------------------------------------------------
1 | # https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/train_cifar10.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 |
8 | import torchvision
9 | import torchvision.transforms as transforms
10 |
11 | import time
12 | import argparse
13 | from util import *
14 | from vit import *
15 |
16 |
17 | parser = argparse.ArgumentParser()
18 |
19 | parser.add_argument("--dataset", type=str, default="cifar10")
20 | parser.add_argument("--data_path", type=str, default="./data")
21 | parser.add_argument('--nowandb', action='store_true', help='disable wandb')
22 |
23 | parser.add_argument('--opt', default="adam")
24 | parser.add_argument('--scheduler', default="cos")
25 | parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
26 | parser.add_argument('--wd', default=0.0, type=float)
27 | parser.add_argument('--epochs', default=200, type=int)
28 | parser.add_argument('--batch-size', default=512, type=int)
29 | parser.add_argument('--num_workers', default=16, type=int)
30 | parser.add_argument('--data_aug', action='store_true')
31 | parser.add_argument('--noamp', action='store_true', help='disable mixed precision training. for older pytorch versions')
32 |
33 | parser.add_argument('--dim', default=512, type=int)
34 | parser.add_argument('--heads', default=8, type=int)
35 | parser.add_argument('--depth', default=6, type=int)
36 | parser.add_argument('--psize', default=2, type=int)
37 | parser.add_argument('--mlp_dim', default=512, type=int)
38 | parser.add_argument('--dim_head', default=64, type=int)
39 |
40 | parser.add_argument('--input_pe', action='store_true', help='use pe at input')
41 | parser.add_argument("--init", type=str, default="none")
42 | parser.add_argument("--pe_choice", type=str, default="sin")
43 | parser.add_argument('--use_value', action='store_true', help='use value')
44 | parser.add_argument("--spatial_pe", action='store_true', help='use pe for spatial mixing')
45 | parser.add_argument("--spatial_x", action='store_true', help='use x for spatial mixing')
46 | parser.add_argument('--alpha', default=0.5, type=float, help='balance pe and x')
47 | parser.add_argument("--trainable", action='store_true', help='let spatial mixing to be trainable')
48 |
49 |
50 | args = parser.parse_args()
51 | use_amp = not args.noamp
52 | print(args)
53 |
54 |
55 | usewandb = not args.nowandb
56 | if usewandb:
57 | import wandb
58 | watermark = "{}_h{}".format(args.init,args.heads)
59 | wandb.init(project=f'vit-{args.dataset}', name=watermark)
60 | wandb.config.update(args)
61 |
62 |
63 | print(f'==> Preparing {args.dataset} data..')
64 | if args.dataset[:5] == 'cifar':
65 | image_size = 32
66 | dataset_mean = (0.4914, 0.4822, 0.4465)
67 | dataset_std = (0.2471, 0.2435, 0.2616)
68 | elif args.dataset == 'svhn':
69 | image_size = 32
70 | dataset_mean = (0.4376821, 0.4437697, 0.47280442)
71 | dataset_std = (0.19803012, 0.20101562, 0.19703614)
72 | elif args.dataset == 'tiny_imagenet':
73 | image_size = 64
74 | args.psize = 4
75 | print(args)
76 | dataset_mean = (0.485, 0.456, 0.406)
77 | dataset_std = (0.229, 0.224, 0.225)
78 | else:
79 | raise ValueError(f'unsupported dataset: {args.dataset}')
80 | if args.data_aug:
81 | train_transform = transforms.Compose([
82 | transforms.RandAugment(2, 14),
83 | transforms.RandomHorizontalFlip(),
84 | transforms.RandomResizedCrop(image_size),
85 | transforms.ToTensor(),
86 | transforms.Normalize(dataset_mean, dataset_std)
87 | ])
88 | else:
89 | train_transform = transforms.Compose([
90 | transforms.Resize(image_size),
91 | transforms.ToTensor(),
92 | transforms.Normalize(dataset_mean, dataset_std)
93 | ])
94 |
95 | test_transform = transforms.Compose([
96 | transforms.Resize(image_size),
97 | transforms.ToTensor(),
98 | transforms.Normalize(dataset_mean, dataset_std)
99 | ])
100 | if args.dataset == 'cifar10':
101 | n_class = 10
102 | trainset = torchvision.datasets.CIFAR10(root=args.data_path, train=True, download=True, transform=train_transform)
103 | testset = torchvision.datasets.CIFAR10(root=args.data_path, train=False, download=True, transform=test_transform)
104 | elif args.dataset == 'cifar100':
105 | n_class = 100
106 | trainset = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=train_transform)
107 | testset = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=test_transform)
108 | elif args.dataset == 'svhn':
109 | n_class = 10
110 | trainset = torchvision.datasets.SVHN(root=args.data_path, split='train', download=True, transform=train_transform)
111 | testset = torchvision.datasets.SVHN(root=args.data_path, split='test', download=True, transform=test_transform)
112 | elif args.dataset == 'tiny_imagenet':
113 | n_class = 200
114 | create_val_img_folder(args)  # arrange val images into per-class folders (see util.py)
115 | trainset = torchvision.datasets.ImageFolder(root=args.data_path+'/tiny-imagenet-200/train', transform=train_transform)
116 | testset = torchvision.datasets.ImageFolder(root=args.data_path+'/tiny-imagenet-200/val/images', transform=test_transform)
117 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
118 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
119 |
120 | print('==> Building model..')
121 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
122 | model = SimpleViT(image_size=image_size, patch_size=args.psize, num_classes=n_class, dim=args.dim, depth=args.depth, heads=args.heads, mlp_dim=args.mlp_dim, dim_head=args.dim_head,
123 | input_pe=args.input_pe, pe_choice=args.pe_choice, use_value=args.use_value, spatial_pe=args.spatial_pe, spatial_x=args.spatial_x, init=args.init, alpha=args.alpha, trainable=args.trainable)
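# --spatial_pe / --spatial_x control what feeds the attention's q&k: positional
# encodings, patch features, or an alpha-weighted blend of both (see Attention.forward in vit.py).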
124 | if 'cuda' in device:
125 | print(device)
126 | print("using data parallel")
127 | model = torch.nn.DataParallel(model).cuda()
128 | cudnn.benchmark = True
129 | criterion = nn.CrossEntropyLoss()
130 |
131 | if args.opt == "adam":
132 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
133 | elif args.opt == "adamw":
134 | optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
135 |
136 | if args.scheduler == "cos":
137 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
138 | scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
139 |
140 | num_param = count_parameters(model)
141 | for epoch in range(args.epochs):
142 | start = time.time()
143 | train_loss, train_acc, n = 0, 0, 0
144 | for i, (X, y) in enumerate(trainloader):
145 | model.train()
146 | X, y = X.cuda(), y.cuda()
147 | with torch.cuda.amp.autocast(enabled=use_amp):
148 | output = model(X)
149 | loss = criterion(output, y)
150 | scaler.scale(loss).backward()
151 | scaler.step(optimizer)
152 | scaler.update()
153 | optimizer.zero_grad()
154 |
155 | train_loss += loss.item() * y.size(0)
156 | train_acc += (output.max(1)[1] == y).sum().item()
157 | n += y.size(0)
158 | train_acc = train_acc/n
159 | train_loss = train_loss/n
160 |
161 |
162 | model.eval()
163 | test_acc, m = 0, 0
164 | with torch.no_grad():
165 | for i, (X, y) in enumerate(testloader):
166 | X, y = X.cuda(), y.cuda()
167 | with torch.cuda.amp.autocast(enabled=use_amp):
168 | output = model(X)
169 | test_acc += (output.max(1)[1] == y).sum().item()
170 | m += y.size(0)
171 | test_acc = test_acc/m
172 | scheduler.step()
173 |
174 | if usewandb:
175 | wandb.log({'epoch': epoch, 'train_loss': train_loss, 'train_acc': train_acc, "val_acc": test_acc, "lr": optimizer.param_groups[0]["lr"],
176 | "epoch_time": time.time()-start, 'num_param':num_param})
177 | else:
178 | print(f'epoch: {epoch}, train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, val_acc: {test_acc:.4f}, lr: {optimizer.param_groups[0]["lr"]:.6f}, epoch_time: {time.time()-start:.1f}, num_param:{num_param}')
--------------------------------------------------------------------------------
/convmixer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | # initialization of convolution kernels, shape C*1*K*K
7 | def SpatialConv2d_init(C, kernel_size, init='random'):
8 | if (init == 'random')|(init == 'softmax'):
9 | weight = 1/kernel_size*(2*torch.rand((C,1,kernel_size,kernel_size))-1)
10 | elif init == 'impulse':
11 | k = torch.randint(0,kernel_size*kernel_size,(C,1))
12 | weight = torch.zeros((C,1,kernel_size*kernel_size))
13 | for i in range(C):
14 | for j in range(1):
15 | weight[i,j,k[i,j]] = 1
16 | weight = np.sqrt(1/3)*weight.reshape(C,1,kernel_size,kernel_size)
17 | elif init[:3] == 'box':
18 | weight = torch.zeros((C,1,kernel_size*kernel_size))
19 | for i in range(C):
20 | for j in range(1):
21 | k = np.random.choice(kernel_size*kernel_size,int(init[3:]),replace=False)
22 | weight[i,j,k] = 1
23 | weight = np.sqrt(1/int(init[3:])/3)*weight.reshape(C,1,kernel_size,kernel_size)
24 | elif init[:3] == 'gau':
25 | k = torch.randint(0,kernel_size,(C,1,2))
26 | weight = torch.zeros((C,1,kernel_size,kernel_size))
27 | for i in range(C):
28 | for j in range(1):
29 | for p in range(kernel_size):
30 | for q in range(kernel_size):
31 | weight[i,j,p,q] = (-0.5/float(init[3:])*((p-k[i,j,0])**2+(q-k[i,j,1])**2)).exp()
32 | weight = weight/((weight.flatten(1,3)**2).sum(1).mean()*3).sqrt()
33 | else:
34 | return -1
35 | return weight
36 |
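# Usage sketch (illustrative): these kernels act as a depthwise conv, one K*K filter
# per channel; e.g. for C=64 channels and a 5x5 'box25' kernel:
#   w = SpatialConv2d_init(64, 5, init='box25')                      # (64, 1, 5, 5)
#   y = torch.nn.functional.conv2d(x, w, padding=2, groups=64)       # x: (B, 64, H, W)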
37 |
38 | # initialization of convolution kernels in linear format, shape C*(H*W)*(H*W); can easily run out of memory!
39 | def SpatialConv2d_Linear_init(C, H, W, init='perm'):
40 | weight = torch.zeros((C,H*W,H*W))
41 | if init == 'perm':
42 | for i in range(C):
43 | k = torch.randint(0,H*W,(H*W,))
44 | for j in range(H*W):
45 | weight[i][j,k[j]] = 1
46 | weight = np.sqrt(1/3)*weight
47 | elif init == 'fullperm':
48 | for i in range(C):
49 | k = torch.randperm(H*W)
50 | for j in range(H*W):
51 | weight[i][j,k[j]] = 1
52 | weight = np.sqrt(1/3)*weight
53 | elif init[:7] == 'impulse':
54 | ff = int(init[7:])
55 | weight = torch.zeros((C,H*W,H*W))
56 | k = torch.randint(0,ff**2,(C,))
57 | for i in range(C):
58 | m = (k[i]//ff)-(ff//2)
59 | n = (k[i]%ff)-(ff//2)
60 | tmp_weight = torch.zeros((W,W))
61 | for j in range(0-min(0,n),W-max(0,n)):
62 | tmp_weight[j,j+n] = 1
63 | for j in range(0-min(0,m),H-max(0,m)):
64 | weight[i,j*W:(j+1)*W,(j+m)*W:(j+m+1)*W] = tmp_weight
65 | weight = np.sqrt(1/3)*weight
66 | else:
67 | return -1
68 | return weight
69 |
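# Usage sketch (illustrative): in linear format, each channel's spatial mixing is a
# dense (H*W)x(H*W) matrix applied to the flattened feature map; e.g. for C=8, H=W=16:
#   w = SpatialConv2d_Linear_init(8, 16, 16, init='impulse5')        # (8, 256, 256)
#   y = torch.einsum('cpq,bcq->bcp', w, x.flatten(2))                # x: (B, 8, 16, 16)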
70 |
71 |
72 | # my spatial conv module (groups = #channels); heads controls the number of distinct conv filters
73 | class SpatialConv2d(nn.Module):
74 | def __init__(self, C, kernel_size, bias=True, init='random', heads = -1, trainable= True, input_weight=None):
75 | super(SpatialConv2d, self).__init__()
76 | self.C = C
77 | self.kernel_size = kernel_size
78 | self.init = init
79 |
80 | # different initialisation
81 | weight = SpatialConv2d_init(C,kernel_size,init=init)
82 |
83 | # how many heads or different filters we want to use
84 | if (heads<1)|(heads>C) :
85 | heads = C
86 | self.choice_idx = np.random.choice(heads,C,replace=(heads<C))
# [lines 87-122 are lost in this dump: the rest of SpatialConv2d and the start of SpatialConv2d_LinearFormat]
123 | if (heads<1)|(heads>C) :
124 | heads = C
125 | self.choice_idx = np.random.choice(heads,C,replace=(heads<C))
# [lines 126-173 are lost in this dump: the rest of SpatialConv2d_LinearFormat, the Residual helper, and the start of ConvMixer]
174 | if (heads<1)|(heads>dim) :
175 | heads = dim
176 | else:
177 | self.input_weight = None
178 |
179 | # choose spatial conv format
180 | if linear_format:
181 | H=int(image_size/patch_size)
182 | W=int(image_size/patch_size)
183 | if input_weight:
184 | self.input_weight = SpatialConv2d_Linear_init(dim, H, W, init=init)
185 | self.input_weight = nn.Parameter(self.input_weight[:heads],requires_grad=spatial)
186 | for _ in range(depth):
187 | self.mixer.append(nn.ModuleList([
188 | Residual(nn.Sequential(
189 | SpatialConv2d_LinearFormat(dim, H, W, bias=spatial_bias, init=init, heads= heads, trainable=spatial, input_weight=self.input_weight),
190 | nn.GELU(), nn.BatchNorm2d(dim))),
191 | nn.Sequential(nn.Conv2d(dim, dim, kernel_size=1), nn.GELU(), nn.BatchNorm2d(dim))
192 | ]))
193 | else:
194 | if input_weight:
195 | self.input_weight = SpatialConv2d_init(dim, kernel_size, init=init)
196 | self.input_weight = nn.Parameter(self.input_weight[:heads],requires_grad=spatial)
197 | for _ in range(depth):
198 | self.mixer.append(nn.ModuleList([
199 | Residual(nn.Sequential(
200 | SpatialConv2d(dim, kernel_size, bias=spatial_bias, init=init, heads= heads, trainable=spatial, input_weight=self.input_weight),
201 | nn.BatchNorm2d(dim))),
202 | # NB: no GELU after the spatial conv here, unlike the linear-format branch above!
203 | nn.Sequential(nn.Conv2d(dim, dim, kernel_size=1),nn.GELU(),nn.BatchNorm2d(dim))
204 | ]))
205 |
206 | self.output = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),
207 | nn.Flatten(),
208 | nn.Linear(dim, n_classes))
209 |
210 | def forward(self, x):
211 | if self.return_embeding: xs = [x]
212 | x = self.input_embedding(x)
213 | if self.return_embeding: xs.append(x)
214 | for spatial,channel in self.mixer:
215 | x = spatial(x)
216 | x = channel(x)
217 | if self.return_embeding: xs.append(x)
218 | x = self.output(x)
219 | if self.return_embeding:
220 | return x, xs
221 | else:
222 | return x
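# Construction sketch (illustrative, mirroring the call in mytrain_convmixer.py):
#   model = ConvMixer(512, 6, patch_size=2, kernel_size=5, n_classes=10, image_size=32,
#                     init='box25', heads=512, spatial=True, spatial_bias=True)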
--------------------------------------------------------------------------------
/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.optim as optim
4 |
5 | import numpy as np
6 | from einops import rearrange
7 | from einops.layers.torch import Rearrange
8 |
9 | # helpers
10 |
11 | def pair(t):
12 | return t if isinstance(t, tuple) else (t, t)
13 |
14 | def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
15 | y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
16 | assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
17 | omega = torch.arange(dim // 4) / (dim // 4 - 1)
18 | omega = 1.0 / (temperature ** omega)
19 |
20 | y = y.flatten()[:, None] * omega[None, :]
21 | x = x.flatten()[:, None] * omega[None, :]
22 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
23 | return pe.type(dtype)
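# e.g. posemb_sincos_2d(16, 16, 512) returns a (256, 512) table, one row per patch
# in row-major order; SimpleViT blends it with the patch embeddings below.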
24 |
25 |
26 | # my impulse initialization function: fit low-rank Q,K so that softmax(scale*Q@K) approximates a random-shift (impulse) attention map
27 | def impulse_init(heads,img_size,att_rank,ff,scale=1.0,spatial_pe=None,norm=1):
28 | weight = torch.zeros((heads,img_size**2,img_size**2))
29 | k = torch.randint(0,ff**2,(heads,))
30 | for i in range(heads):
31 | m = (k[i]//ff)-(ff//2)
32 | n = (k[i]%ff)-(ff//2)
33 | tmp_weight = torch.zeros((img_size,img_size))
34 | for j in range(0-min(0,n),img_size-max(0,n)):
35 | tmp_weight[j,j+n] = 1
36 | for j in range(0-min(0,m),img_size-max(0,m)):
37 | weight[i,j*img_size:(j+1)*img_size,(j+m)*img_size:(j+m+1)*img_size] = tmp_weight
38 | # weight = np.sqrt(1/3)*weight
39 | class PermuteM(nn.Module):
40 | def __init__(self, heads, img_size, att_rank,scale=1.0,spatial_pe=None):
41 | super().__init__()
42 | self.scale = scale
43 | if spatial_pe is None:
44 | self.spatial_pe = False
45 | weights_Q = np.sqrt(1/att_rank/heads)*(2*torch.rand(heads,img_size,att_rank)-1)
46 | weights_K = np.sqrt(1/att_rank/heads)*(2*torch.rand(heads,att_rank,img_size)-1)
47 | else:
48 | self.spatial_pe = True
49 | self.pe = spatial_pe.cuda()
50 | weights_Q = np.sqrt(1/att_rank/heads)*(2*torch.rand(heads,spatial_pe.shape[1],att_rank)-1)
51 | weights_K = np.sqrt(1/att_rank/heads)*(2*torch.rand(heads, att_rank, spatial_pe.shape[1])-1)
52 |
53 | self.weights_K = nn.Parameter(weights_K)
54 | self.weights_Q = nn.Parameter(weights_Q)
55 | def forward(self):
56 | if self.spatial_pe:
57 | M = self.pe@self.weights_Q@self.weights_K@(self.pe.T)
58 | else:
59 | M = torch.bmm(self.weights_Q,self.weights_K)
60 | return torch.softmax(M*self.scale,-1)
61 |
62 | net = PermuteM(heads,img_size**2,att_rank,scale,spatial_pe)
63 | net.cuda()
64 |
65 | nq = net.weights_Q.detach().cpu().norm(dim=(1)).mean()
66 | weight = weight.cuda()
67 | num_epoch = 10000
68 | criterion = nn.MSELoss()
69 | optimizer = optim.Adam(net.parameters(), lr=0.0001)#,weight_decay=1e-6)
70 | for i in range(num_epoch):
71 | if i%norm==0:
72 | with torch.no_grad():
73 | net.weights_Q.div_(net.weights_Q.detach().norm(dim=(1),keepdim=True)/nq)
74 | net.weights_K.div_(net.weights_K.detach().norm(dim=(1),keepdim=True)/nq)
75 | optimizer.zero_grad()
76 | outputs = net()
77 | loss = criterion(outputs, weight)
78 | loss.backward()
79 | optimizer.step()
80 | print(loss.data)
81 |
82 | return net.weights_Q.detach().cpu(),net.weights_K.detach().cpu()
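# Usage sketch (illustrative): the factors behind init='impulse16_64_5_0.1_100' in vit_pex.sh,
# i.e. a 16x16 patch grid, rank 64, 5x5 shift range, scale 0.1, renormalize every 100 steps:
#   Q, K = impulse_init(heads=8, img_size=16, att_rank=64, ff=5, scale=0.1, norm=100)
# Note the 10000-step Adam fit runs on the GPU, so building a model with this init is slow.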
83 |
84 | # classes
85 |
86 | class FeedForward(nn.Module):
87 | def __init__(self, dim, hidden_dim):
88 | super().__init__()
89 | self.net = nn.Sequential(
90 | nn.LayerNorm(dim),
91 | nn.Linear(dim, hidden_dim),
92 | nn.GELU(),
93 | nn.Linear(hidden_dim, dim),
94 | )
95 | def forward(self, x):
96 | return self.net(x)
97 |
98 | class Attention(nn.Module):
99 | def __init__(self, dim, heads = 8, dim_head = 64, use_value = True, spatial_pe = None,
100 | spatial_x = True, init = 'none', alpha=1.0, trainable=True, out_layer=True):
101 | super().__init__()
102 | inner_dim = dim_head * heads
103 | self.scale = dim_head ** -0.5
104 | self.heads = heads
105 | self.norm = nn.LayerNorm(dim)
106 | self.attend = nn.Softmax(dim = -1)
107 |
108 | self.alpha = alpha
109 |
110 | # input to q&k
111 | self.spatial_x = spatial_x
112 | self.spatial_pe = False
113 | if spatial_pe is not None:
114 | self.spatial_pe = True
115 | self.pos_embedding = spatial_pe
116 |
117 | # format & initialization of q&k
118 | self.init = init
119 | if init == 'none':
120 | self.to_qk = nn.Linear(dim, inner_dim*2, bias = False)
121 | else:
122 | if init[:7] == 'impulse':
123 | a, b, c, d, e = init[7:].split('_')
124 | img_size = int(a)
125 | att_rank = int(b)
126 | ff = int(c)
127 | self.scale = float(d)
128 | norm = int(e)
129 | Q, K = impulse_init(heads,img_size,att_rank,ff,self.scale,spatial_pe,norm)
130 | elif init[:6] == 'random':
131 | a, b = init[6:].split('_')
132 | img_size = int(a)
133 | att_rank = int(b)
134 | Q = np.sqrt(1/img_size)*(2*torch.rand(heads,img_size,att_rank)-1)
135 | K = np.sqrt(1/img_size)*(2*torch.rand(heads,att_rank,img_size)-1)
136 | elif init[:7] == 'mimetic':
137 | a, b = init[7:].split('_')
138 | img_size = int(a)
139 | att_rank = int(b)
140 | W = 0.7*np.sqrt(1/img_size)*(2*torch.rand(heads,img_size,img_size)-1)+0.7*torch.eye(img_size).unsqueeze(0).repeat(heads,1,1)
141 | U,s,V = torch.linalg.svd(W)
142 | s_2 = torch.sqrt(s)
143 | Q = torch.matmul(U[:,:,:att_rank], torch.diag_embed(s_2)[:,:att_rank,:att_rank])
144 | K = torch.matmul(torch.diag_embed(s_2)[:,:att_rank,:att_rank], V[:,:att_rank,:])
145 | if self.spatial_pe|self.spatial_x:
146 | print('use linear format')
147 | self.to_qk = nn.Linear(dim, inner_dim*2, bias = False)
148 | self.to_qk.weight.data[:inner_dim,:] = rearrange(Q, 'h n d -> n (h d)').T
149 | self.to_qk.weight.data[inner_dim:,:] = rearrange(K, 'h d n -> n (h d)').T
150 | else:
151 | print('use Q K format')
152 | self.Q = nn.Parameter(Q,requires_grad=trainable)
153 | self.K = nn.Parameter(K,requires_grad=trainable)
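# init-string grammar, as parsed above: 'impulse{img}_{rank}_{ff}_{scale}_{norm}',
# 'random{img}_{rank}', 'mimetic{img}_{rank}', or 'none' (plain learned q&k);
# these are exactly the values swept in vit_pex.sh.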
154 |
155 | # use v or just use x
156 | self.use_value = use_value
157 | if use_value:
158 | self.to_v = nn.Linear(dim, inner_dim, bias = False)
159 |
160 | # use output layer or not
161 | self.out_layer = out_layer
162 | if self.out_layer:
163 | self.to_out = nn.Linear(inner_dim, dim, bias = False)
164 |
165 |
166 |
167 | def forward(self, x):
168 | x = self.norm(x)
169 |
170 | # use v or just use x
171 | if self.use_value:
172 | v = self.to_v(x)
173 | else:
174 | v = x
175 |
176 | # q&k format
177 | if self.spatial_pe|self.spatial_x:
178 | # blend pe and x as the input to q&k
179 | device = x.device
180 | if self.spatial_pe&self.spatial_x:
181 | x = self.alpha*x + (1-self.alpha)*self.pos_embedding.to(device, dtype=x.dtype)
182 | elif self.spatial_pe:
183 | x = 0*x + self.pos_embedding.to(device, dtype=x.dtype)
184 | qk = self.to_qk(x).chunk(2, dim = -1)
185 | q, k = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qk)
186 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
187 | else:
188 | dots = torch.matmul(self.Q, self.K) * self.scale
189 | attn = self.attend(dots)
190 |
191 | out = torch.matmul(attn, rearrange(v, 'b n (h d) -> b h n d', h = self.heads))
192 | out = rearrange(out, 'b h n d -> b n (h d)')
193 | if self.out_layer:
194 | return self.to_out(out)
195 | else:
196 | return out
197 |
198 |
199 |
200 | class Transformer(nn.Module):
201 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, use_value=True, spatial_pe=None, spatial_x = True, init = 'none', alpha=1.0, trainable=False):
202 | super().__init__()
203 | self.norm = nn.LayerNorm(dim)
204 | self.layers = nn.ModuleList([])
205 | for _ in range(depth):
206 | self.layers.append(nn.ModuleList([
207 | Attention(dim, heads, dim_head, use_value, spatial_pe, spatial_x, init, alpha, trainable),
208 | FeedForward(dim, mlp_dim)
209 | ]))
210 | def forward(self, x):
211 | for attn, ff in self.layers:
212 | x = attn(x) + x
213 | x = ff(x) + x
214 | return self.norm(x)
215 |
216 | class SimpleViT(nn.Module):
217 | def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64,
218 | input_pe = True, pe_choice='sin', use_value = True, spatial_pe = False, spatial_x = True, init = 'none', alpha=0.5, trainable=False):
219 | super().__init__()
220 |
221 | self.input_pe = input_pe
222 | self.use_value = use_value
223 | self.alpha = alpha
224 | if input_pe: alpha_inside = 1.0
225 | else: alpha_inside = alpha
226 |
227 | image_height, image_width = pair(image_size)
228 | patch_height, patch_width = pair(patch_size)
229 |
230 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
231 |
232 | patch_dim = channels * patch_height * patch_width
233 |
234 | self.to_patch_embedding = nn.Sequential(
235 | Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
236 | nn.LayerNorm(patch_dim),
237 | nn.Linear(patch_dim, dim),
238 | nn.LayerNorm(dim),
239 | )
240 |
241 | if pe_choice == 'sin':
242 | self.pos_embedding = posemb_sincos_2d(
243 | h = image_height // patch_height,
244 | w = image_width // patch_width,
245 | dim = dim,
246 | )
247 | elif pe_choice == 'identity':
248 | # self.pos_embedding = torch.eye(64).repeat(1,8).type(torch.float32)
249 | s = (image_height // patch_height)*(image_width // patch_width)
250 | self.pos_embedding = torch.cat([torch.eye(s),torch.zeros(s,dim-s)],dim=-1).type(torch.float32)
251 | if spatial_pe:
252 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_value, self.pos_embedding, spatial_x, init, alpha_inside, trainable)
253 | else: # change dim_heads here
254 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, use_value, None, spatial_x, init, alpha_inside, trainable)
255 |
256 | self.pool = "mean"
257 | self.to_latent = nn.Identity()
258 |
259 | self.linear_head = nn.Linear(dim, num_classes)
260 |
261 | def forward(self, img):
262 | device = img.device
263 |
264 | x = self.to_patch_embedding(img)
265 | if self.input_pe:
266 | x = self.alpha*x + (1-self.alpha)*self.pos_embedding.to(device, dtype=x.dtype)
267 |
268 | x = self.transformer(x)
269 | x = x.mean(dim = 1)
270 |
271 | x = self.to_latent(x)
272 | return self.linear_head(x)
--------------------------------------------------------------------------------