├── models
│   ├── __init__.py
│   ├── embeddings.py
│   ├── base.py
│   ├── gpt.py
│   ├── attentions.py
│   ├── blocks.py
│   └── vit.py
├── utils.py
├── vision_utils.py
├── README.md
├── nlp_utils.py
├── vit.py
└── gpt.py
/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/embeddings.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | 
5 | class PatchEmbed2d(nn.Module):
6 |     def __init__(self,
7 |                  patch_size: int or tuple,
8 |                  emb_dim: int,
9 |                  in_channels: int
10 |                  ):
11 |         super().__init__()
12 |         self.proj = nn.Conv2d(in_channels, emb_dim, kernel_size=patch_size, stride=patch_size)
13 | 
14 |     def forward(self,
15 |                 input: torch.Tensor
16 |                 ) -> torch.Tensor:
17 |         # input: BxCxHxW -> BxNxC'
18 |         return self.proj(input).flatten(2).transpose(1, 2)
19 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from typing import Callable, Optional
3 | 
4 | from homura import init_distributed, is_distributed
5 | 
6 | 
7 | def distributed_ready_main(func: Callable = None,
8 |                            backend: Optional[str] = None,
9 |                            init_method: Optional[str] = None,
10 |                            disable_distributed_print: bool = False
11 |                            ) -> Callable:
12 |     """ Wrap a main function to make it distributed ready
13 |     """
14 | 
15 |     if is_distributed():
16 |         init_distributed(backend=backend, init_method=init_method, disable_distributed_print=disable_distributed_print)
17 | 
18 |     @wraps(func)
19 |     def inner(*args, **kwargs):
20 |         return func(*args, **kwargs)
21 | 
22 |     return inner
23 | 
--------------------------------------------------------------------------------
/vision_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import torch
4 | 
5 | 
6 | def fast_collate(batch: list
7 |                  ) -> Tuple[torch.Tensor, torch.Tensor]:
8 |     # based on NVIDIA's Apex
9 |     # but it's not faster than the default, probably because transforms.ToTensor is too slow.
10 |     imgs = torch.stack([img for img, target in batch], dim=0)
11 |     targets = torch.tensor([target for img, target in batch], dtype=torch.int64)
12 |     return imgs, targets
13 | 
14 | 
15 | def gen_mixup_collate(alpha):
16 |     # see https://github.com/moskomule/mixup.pytorch
17 |     beta = torch.distributions.Beta(alpha + 1, alpha)
18 | 
19 |     def f(batch):
20 |         tensors, targets = fast_collate(batch)
21 |         indices = torch.randperm(tensors.size(0))
22 |         _tensors = tensors.clone()[indices]
23 |         gamma = beta.sample()
24 |         tensors.mul_(gamma).add_(_tensors, alpha=1 - gamma)
25 |         return tensors, targets
26 | 
27 |     return f
28 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simple Transformers
2 | 
3 | Simple transformer implementations that I can understand.
4 | 
5 | ## Requirements
6 | 
7 | ```commandline
8 | conda create -n transformer python=3.9
9 | conda activate transformer
10 | conda install -c pytorch -c conda-forge pytorch torchvision cudatoolkit=11.1
11 | pip install -U homura-core chika rich
12 | # for NLP, also install
13 | pip install -U datasets tokenizers
14 | # To use checkpointing
15 | pip install -U fairscale
16 | # To accelerate (probably only slightly),
17 | pip install -U opt_einsum
18 | ```
19 | 
20 | ## Examples
21 | 
22 | ### Language Modeling
23 | 
24 | #### GPT
25 | 
26 | Train GPT-like models on WikiText-103 or Gigaword. Currently, improved pre-LN, pre-LN, and post-LN Transformer blocks are
27 | available for comparison.
28 | 
29 | ```commandline
30 | python gpt.py [--model.block {ipre_ln, pre_ln, post_ln}] [--amp]
31 | ```
32 | 
33 | #### BERT
34 | 
35 | Work in progress
36 | 
37 | ### Image Recognition
38 | 
39 | Train ImageNet classification models.
40 | 
41 | Currently, ViT and CaiT are implemented.
42 | 
43 | ```commandline
44 | # single-process training
45 | python vit.py [--amp] [--model.ema]
46 | # for multi-process training,
47 | python -m torch.distributed.launch --nproc_per_node=2 vit.py ...
48 | ```
49 | 
50 | ## Acknowledgement
51 | 
52 | For this project, I learned a lot from Andrej's [minGPT](https://github.com/karpathy/mingpt),
53 | Ross's [timm](https://github.com/rwightman/pytorch-image-models), and
54 | FAIR's [ClassyVision](https://github.com/facebookresearch/ClassyVision).
55 | 
--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | from torch import nn
4 | from torch.nn.modules.conv import _ConvNd
5 | 
6 | 
7 | class TransformerBase(nn.Module):
8 |     """ Base class that supports checkpointing and weight_decay branching
9 | 
10 |     """
11 | 
12 |     def __init__(self,
13 |                  blocks: nn.Sequential,
14 |                  checkpointing: bool):
15 |         super().__init__()
16 |         self._blocks = blocks
17 |         self._blocks_train = self._blocks
18 |         if checkpointing:
19 |             from fairscale.nn.misc import checkpoint_wrapper
20 |             self._blocks_train = checkpoint_wrapper(self._blocks)
21 | 
22 |     @property
23 |     def blocks(self):
24 |         if self.training:
25 |             return self._blocks_train
26 |         else:
27 |             return self._blocks
28 | 
29 |     def init_weights(self):
30 |         raise NotImplementedError
31 | 
32 |     @property
33 |     def param_groups(self
34 |                      ) -> dict[str, list]:
35 | 
36 |         decay = set()
37 |         no_decay = set()
38 |         apply_decay = (nn.Linear, _ConvNd)
39 |         no_apply_decay = (nn.LayerNorm, nn.Embedding)
40 |         for name, param in self.named_parameters():
41 |             if "pos_emb" in name or "token" in name:
42 |                 no_decay.add(param)
43 |         for module in self.modules():
44 |             if isinstance(module, no_apply_decay):
45 |                 for param in module.parameters():
46 |                     no_decay.add(param)
47 |             elif isinstance(module, apply_decay):
48 |                 decay.add(module.weight)
49 |                 if module.bias is not None:
50 |                     no_decay.add(module.bias)
51 |         for param in self.parameters():
52 |             if param not in no_decay:
53 |                 decay.add(param)
54 |         assert len([param for param in self.parameters()]) == len(decay) + len(no_decay)
55 |         return {"decay": list(decay), "no_decay": list(no_decay)}
56 | 
57 |     @classmethod
58 |     def construct(cls,
59 |                   *args,
60 |                   **kwargs
61 |                   ) -> TransformerBase:
62 |         raise NotImplementedError
63 | 
--------------------------------------------------------------------------------
/nlp_utils.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from os import environ
3 | 
4 | import datasets
5 | from datasets.utils.logging import set_verbosity_error
6 | from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, processors, trainers
7 | from torch.utils.data import DataLoader
8 | 
9 | 
10 | def get_data(name,
11 |              batch_size,
12 |              max_len,
13 |              num_workers=4,
14 |              train_full=False
15 |              ):
16 |     max_len += 1
17 |     datasets.disable_progress_bar()
18 |     set_verbosity_error()
19 | 
20 |     # followed https://github.com/EleutherAI/gpt-neo/
21 |     environ['TOKENIZERS_PARALLELISM'] = 'true'
22 |     tokenizer_path = pathlib.Path(f"{name}_tokenizer.json")
23 |     _name = {"wikitext": ("wikitext", "wikitext-103-v1"),
24 |              "gigaword": ("gigaword",)
25 |              }[name]
26 |     _column_name = {"wikitext": "text",
27 |                     "gigaword": "document"
28 |                     }[name]
29 |     if tokenizer_path.exists():
30 |         tokenizer = Tokenizer.from_file(str(tokenizer_path))
31 |     else:
32 |         dataset = datasets.load_dataset(*_name, split="train+test+validation")
33 |         tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
34 |         tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=True)
35 |         tokenizer.decoder = decoders.ByteLevel()
36 |         tokenizer.normalizer = normalizers.NFKC()
37 |         tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
38 |         tokenizer.enable_truncation(max_length=max_len)
39 |         tokenizer.enable_padding(length=max_len)
40 |         trainer = trainers.BpeTrainer(min_frequency=2, special_tokens=["<pad>", "<unk>", "<s>"])
41 | 
42 |         def batch_iterator(bs):
43 |             for i in range(0, len(dataset), bs):
44 |                 yield dataset[i: i + bs][_column_name]
45 | 
46 |         tokenizer.train_from_iterator(batch_iterator(1_000), trainer=trainer, length=len(dataset))
47 |         tokenizer.save(str(tokenizer_path))
48 | 
49 |     train_ds, val_ds = datasets.load_dataset(*_name,
50 |                                              split=['train' if train_full else 'train[:20%]', 'validation'])
51 | 
52 |     def to_ids(sent):
53 |         tokenized = tokenizer.encode(sent[_column_name])
54 |         return {"ids": tokenized.ids, "mask": tokenized.attention_mask}
55 | 
56 |     train_ds = train_ds.filter(lambda e: len(e[_column_name]) > 20).map(to_ids, num_proc=10)
57 |     val_ds = val_ds.filter(lambda e: len(e[_column_name]) > 20).map(to_ids)
58 |     train_ds.set_format(type='torch', columns=['ids', 'mask'])
59 |     val_ds.set_format(type='torch', columns=['ids', 'mask'])
60 | 
61 |     return (
62 |         DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True),
63 |         DataLoader(val_ds, batch_size=batch_size, num_workers=num_workers),
64 |         tokenizer,
65 |         tokenizer.get_vocab_size()
66 |     )
67 | 
--------------------------------------------------------------------------------
/models/gpt.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from copy import deepcopy
3 | from typing import Optional, Tuple
4 | 
5 | import torch
6 | from torch import nn
7 | 
8 | from .attentions import SelfAttention
9 | from .base import TransformerBase
10 | from .blocks import BLOCK, BlockBase
11 | 
12 | 
13 | class MaskedSequential(nn.Sequential):
14 |     def forward(self,
15 |                 input: torch.Tensor,
16 |                 mask: torch.Tensor
17 |                 ) -> torch.Tensor:
18 |         for module in self:
19 |             input = module(input, mask)
20 |         return input
21 | 
22 | 
23 | class GPT(TransformerBase):
24 |     def __init__(self,
25 |                  block: BlockBase,
26 |                  vocab_size: int,
27 |                  max_len: int,
28 |                  emb_dim: int,
29 |                  num_layers: int,
30 |                  emb_dropout_rate: float,
31 | 
enable_checkpoint: bool = False 32 | ): 33 | super().__init__(MaskedSequential(*[deepcopy(block) for _ in range(num_layers)]), enable_checkpoint) 34 | self.max_len = max_len 35 | self.num_layers = num_layers 36 | self.tok_emb = nn.Embedding(vocab_size, emb_dim) 37 | self.pos_emb = nn.Parameter(torch.zeros(1, max_len, emb_dim)) 38 | self.dropout = nn.Dropout(emb_dropout_rate) 39 | self.head = nn.Sequential(nn.LayerNorm(emb_dim), nn.Linear(emb_dim, vocab_size, bias=False)) 40 | self.register_buffer("mask", torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))[None, None]) 41 | self.init_weights() 42 | 43 | def init_weights(self): 44 | for module in self.modules(): 45 | if isinstance(module, (nn.Linear, nn.Embedding)): 46 | nn.init.normal_(module.weight, 0, 0.02) 47 | if isinstance(module, nn.Linear) and module.bias is not None: 48 | nn.init.zeros_(module.bias) 49 | if isinstance(module, nn.LayerNorm): 50 | nn.init.zeros_(module.bias) 51 | nn.init.ones_(module.weight) 52 | 53 | def forward(self, 54 | input: torch.Tensor, 55 | mask: Optional[torch.Tensor] = None 56 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 57 | b, t = input.size() 58 | _mask = self.mask[:, :, :t, :t] 59 | mask = _mask if mask is None else (_mask & mask[:, None, None, :].bool()) 60 | token_emb = self.tok_emb(input) 61 | pos_emb = self.pos_emb[:, :t, :] 62 | x = self.dropout(token_emb + pos_emb) # BxNxC 63 | x = self.blocks(x, mask) 64 | logits = self.head(x) # BxNxV 65 | return logits 66 | 67 | @classmethod 68 | def construct(cls, 69 | block: str, 70 | vocab_size: int, 71 | max_len: int, 72 | num_heads: int = 12, 73 | emb_dim: int = 768, 74 | num_layers: int = 12, 75 | emb_dropout_rate: float = 0.1, 76 | attn_dropout_rate: float = 0.1, 77 | proj_dropout_rate: float = 0.1, 78 | enable_checkpoint: bool = False, 79 | **kwargs 80 | ): 81 | if len(kwargs) > 0: 82 | warnings.warn(f"kwargs={kwargs} are not used") 83 | block = BLOCK(block)(emb_dim, 84 | SelfAttention(emb_dim, num_heads, attn_dropout_rate, proj_dropout_rate), 85 | proj_dropout_rate) 86 | return cls(block, vocab_size, max_len, emb_dim, num_layers, emb_dropout_rate, enable_checkpoint) 87 | -------------------------------------------------------------------------------- /models/attentions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | import math 5 | from typing import Callable, Optional 6 | 7 | import torch 8 | from homura import Registry 9 | from torch import nn 10 | 11 | try: 12 | import opt_einsum 13 | 14 | print("opt_einsum is installed, so einsum=opt_einsum") 15 | 16 | einsum = functools.partial(opt_einsum.contract, backend="torch") 17 | 18 | except ImportError: 19 | print("no opt_einsum") 20 | 21 | einsum = torch.einsum 22 | 23 | ATTENTIONS = Registry("attentions", ) 24 | 25 | 26 | # helper functions 27 | def _masking(context: torch.Tensor, 28 | mask: torch.Tensor 29 | ) -> torch.Tensor: 30 | if mask is None: 31 | return context 32 | 33 | return context.masked_fill(mask == 0, float('-inf')) 34 | 35 | 36 | def _talking(context: torch.Tensor, 37 | talk_tensor: torch.Tensor 38 | ) -> torch.Tensor: 39 | if talk_tensor is None: 40 | return context 41 | 42 | return einsum("bhmn,hk->bkmn", context, talk_tensor) 43 | 44 | 45 | @ATTENTIONS.register(name="dotprod") 46 | def dotproduct_self_attention(query: torch.Tensor, 47 | key: torch.Tensor, 48 | value: torch.Tensor, 49 | mask: Optional[torch.Tensor] = None, 50 | dropout: Optional[Callable[[torch.Tensor], 
torch.Tensor]] = None, 51 | pre_talk: Optional[torch.Tensor] = None, 52 | post_talk: Optional[torch.Tensor] = None, 53 | ) -> torch.Tensor: 54 | """ dot-product self-attention 55 | 56 | Args: 57 | query: tensor of shape BHKN 58 | key: tensor of shape BHKM 59 | value: tensor of shape BHVN 60 | mask: optional mask 61 | dropout: optional dropout function 62 | pre_talk: optional tensor for talking attention 63 | post_talk: optional tensor for talking attention 64 | 65 | Returns: results 66 | 67 | """ 68 | 69 | # attn/\sqrt{dim_head} 70 | context = einsum("bhkn,bhkm->bhmn", query, key).div(math.sqrt(query.size(-2))) 71 | context = _talking(context, pre_talk) 72 | context = _masking(context, mask) 73 | context = context.softmax(dim=-1) 74 | context = _talking(context, post_talk) 75 | if dropout is not None: 76 | context = dropout(context) 77 | return einsum("bhmn,bhvm->bhvn", context, value) 78 | 79 | 80 | class SelfAttention(nn.Module): 81 | def __init__(self, 82 | emb_dim: int, 83 | num_heads: int, 84 | attn_dropout_rate: float, 85 | proj_dropout_rate: float, 86 | qkv_bias: bool = True, 87 | proj_bias: bool = True, 88 | talking_heads: bool = False 89 | ): 90 | super().__init__() 91 | 92 | self.emb_dim = emb_dim 93 | self.num_heads = num_heads 94 | self.key = nn.Linear(emb_dim, emb_dim, bias=qkv_bias) 95 | self._query = nn.Linear(emb_dim, emb_dim, bias=qkv_bias) 96 | self.value = nn.Linear(emb_dim, emb_dim, bias=qkv_bias) 97 | self.proj = nn.Linear(emb_dim, emb_dim, bias=proj_bias) 98 | self.attn_dropout = nn.Dropout(attn_dropout_rate) 99 | self.proj_dropout = nn.Dropout(proj_dropout_rate) 100 | self.pre_talk, self.post_talk = None, None 101 | if talking_heads: 102 | self.pre_talk = nn.Parameter(torch.randn(self.num_heads, self.num_heads)) 103 | self.post_talk = nn.Parameter(torch.randn(self.num_heads, self.num_heads)) 104 | 105 | def query(self, 106 | input: torch.Tensor 107 | ) -> torch.Tensor: 108 | b = input.size(0) 109 | return self._query(input).transpose(-1, -2).view(b, self.num_heads, self.emb_dim // self.num_heads, -1) 110 | 111 | def forward(self, 112 | input: torch.Tensor, 113 | mask: Optional[torch.Tensor] = None 114 | ) -> torch.Tensor: 115 | # input: BxNxC 116 | b = input.size(0) 117 | # BxNxC -> BxCxN -> BxHxC'xN 118 | query = self.query(input) 119 | key = self.key(input).transpose(-1, -2).view(b, self.num_heads, self.emb_dim // self.num_heads, -1) 120 | value = self.value(input).transpose(-1, -2).view(b, self.num_heads, self.emb_dim // self.num_heads, -1) 121 | attention = dotproduct_self_attention(query, key, value, mask, self.attn_dropout, self.pre_talk, self.post_talk) 122 | attention = attention.reshape(b, self.emb_dim, -1).transpose(-1, -2) 123 | return self.proj_dropout(self.proj(attention)) 124 | 125 | 126 | class ClassAttention(SelfAttention): 127 | def query(self, 128 | input: torch.Tensor 129 | ) -> torch.Tensor: 130 | # BxC->BxHxC'x1 131 | return self._query(input[:, 0]).view(input.size(0), self.num_heads, self.emb_dim // self.num_heads, 1) 132 | -------------------------------------------------------------------------------- /vit.py: -------------------------------------------------------------------------------- 1 | import chika 2 | import homura 3 | import torch 4 | from homura import lr_scheduler, reporters 5 | from homura.modules import SmoothedCrossEntropy 6 | from homura.trainers import SupervisedTrainer 7 | from homura.vision.data import DATASET_REGISTRY 8 | from torchvision.transforms import AutoAugment, RandomErasing 9 | 10 | from models.vit import ViTEMA, 
ViTs 11 | from utils import distributed_ready_main 12 | from vision_utils import fast_collate, gen_mixup_collate 13 | 14 | 15 | class ViTTraner(SupervisedTrainer): 16 | def __init__(self, *args, **kwargs): 17 | self.optim_cfg = kwargs.pop('optim_cfg') 18 | super().__init__(*args, **kwargs) 19 | 20 | def set_optimizer(self 21 | ) -> None: 22 | params_dict = self.accessible_model.param_groups 23 | optim_groups = [ 24 | {"params": params_dict['decay'], "weight_decay": self.optim_cfg.weight_decay}, 25 | {"params": params_dict['no_decay'], "weight_decay": 0} 26 | ] 27 | self.optimizer = torch.optim._multi_tensor.AdamW(optim_groups, 28 | lr=self.optim_cfg.lr, 29 | weight_decay=self.optim_cfg.weight_decay) 30 | self.logger.debug(self.optimizer) 31 | 32 | 33 | @chika.config 34 | class DataConfig: 35 | batch_size: int = 128 36 | autoaugment: bool = False 37 | random_erasing: bool = False 38 | mixup: float = 0 39 | 40 | 41 | @chika.config 42 | class ModelConfig: 43 | name: str = chika.choices(*ViTs.choices()) 44 | dropout_rate: float = 0 45 | droppath_rate: float = 0 46 | ema: bool = False 47 | ema_rate: float = chika.bounded(0.999, 0, 1) 48 | 49 | 50 | @chika.config 51 | class OptimConfig: 52 | lr: float = 5e-4 53 | weight_decay: float = 0.05 54 | label_smoothing: float = 0.1 55 | epochs: int = 200 56 | min_lr: float = 1e-5 57 | warmup_epochs: int = 5 58 | multiplier: int = 1 59 | 60 | 61 | @chika.config 62 | class Config: 63 | data: DataConfig 64 | model: ModelConfig 65 | optim: OptimConfig 66 | 67 | debug: bool = False 68 | amp: bool = False 69 | gpu: int = None 70 | no_save: bool = False 71 | 72 | def __post_init__(self): 73 | assert self.optim.lr > self.optim.min_lr 74 | self.optim.lr *= self.data.batch_size * homura.get_world_size() / 512 75 | self.optim.min_lr *= self.data.batch_size * homura.get_world_size() / 512 76 | 77 | 78 | @chika.main(cfg_cls=Config, change_job_dir=True) 79 | @distributed_ready_main 80 | def main(cfg: Config): 81 | if cfg.gpu is not None: 82 | torch.cuda.set_device(cfg.gpu) 83 | if homura.is_master(): 84 | import rich 85 | rich.print(cfg) 86 | vs = DATASET_REGISTRY("imagenet") 87 | vs.collate_fn = fast_collate if cfg.data.mixup == 0 else gen_mixup_collate(cfg.data.mixup) 88 | model = ViTs(cfg.model.name)(droppath_rate=cfg.model.droppath_rate, dropout_rate=cfg.model.dropout_rate) 89 | train_da = vs.default_train_da.copy() 90 | test_da = vs.default_test_da.copy() 91 | train_da[0].size = model.image_size 92 | test_da[0].size = model.image_size 93 | test_da[1].size = model.image_size 94 | if cfg.data.autoaugment: 95 | train_da.append(AutoAugment()) 96 | post_da = [RandomErasing()] if cfg.data.random_erasing else None 97 | train_loader, test_loader = vs(batch_size=cfg.data.batch_size, 98 | train_da=train_da, 99 | test_da=test_da, 100 | post_norm_train_da=post_da, 101 | train_size=cfg.data.batch_size * 50 if cfg.debug else None, 102 | test_size=cfg.data.batch_size * 50 if cfg.debug else None, 103 | num_workers=8) 104 | if cfg.model.ema: 105 | model = ViTEMA(model, cfg.model.ema_rate) 106 | scheduler = lr_scheduler.CosineAnnealingWithWarmup(cfg.optim.epochs, multiplier=cfg.optim.multiplier, 107 | warmup_epochs=cfg.optim.warmup_epochs, 108 | min_lr=cfg.optim.min_lr) 109 | 110 | with ViTTraner(model, 111 | None, 112 | SmoothedCrossEntropy(cfg.optim.label_smoothing), 113 | reporters=[reporters.TensorboardReporter(".")], 114 | scheduler=scheduler, 115 | use_amp=cfg.amp, 116 | use_cuda_nonblocking=True, 117 | report_accuracy_topk=5, 118 | optim_cfg=cfg.optim, 119 | debug=cfg.debug 
120 | ) as trainer: 121 | for ep in trainer.epoch_range(cfg.optim.epochs): 122 | trainer.train(train_loader) 123 | trainer.test(test_loader) 124 | trainer.scheduler.step() 125 | if not cfg.no_save: 126 | trainer.save(f"outputs/{cfg.model.name}", f"{ep}") 127 | 128 | print(f"Max Test Accuracy={max(trainer.reporter.history('accuracy/test')):.3f}") 129 | 130 | 131 | if __name__ == '__main__': 132 | import warnings 133 | 134 | # to suppress annoying warnings from PIL 135 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 136 | 137 | main() 138 | -------------------------------------------------------------------------------- /models/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type 2 | 3 | import torch 4 | from homura import Registry 5 | from torch import nn 6 | 7 | from .attentions import SelfAttention 8 | 9 | BLOCK = Registry("block", nn.Module) 10 | 11 | 12 | def act_func(name: str 13 | ) -> nn.Module: 14 | _acts = {"relu": nn.ReLU, 15 | "leaky_relu": nn.LeakyReLU, 16 | "gelu": nn.GELU, 17 | "silu": nn.SiLU} 18 | return _acts[name]() 19 | 20 | 21 | class BlockBase(nn.Module): 22 | def __init__(self, 23 | emb_dim: int, 24 | attention: SelfAttention, 25 | dropout_rate: float, 26 | widen_factor: int = 4, 27 | activation: str = "gelu", 28 | norm: Type[nn.LayerNorm] = nn.LayerNorm): 29 | super().__init__() 30 | self.ln1 = norm(emb_dim) 31 | self.ln2 = norm(emb_dim) 32 | self.attention = attention 33 | self.mlp = nn.Sequential(nn.Linear(emb_dim, widen_factor * emb_dim), 34 | act_func(activation), 35 | nn.Linear(widen_factor * emb_dim, emb_dim), 36 | nn.Dropout(dropout_rate)) 37 | 38 | def forward(self, 39 | input: torch.Tensor, 40 | mask: Optional[torch.Tensor] = None 41 | ) -> torch.Tensor: 42 | raise NotImplementedError 43 | 44 | 45 | @BLOCK.register(name="post_ln") 46 | class PostLNBlock(BlockBase): 47 | """ Transformer Block from "Attention is All You Need" 48 | """ 49 | 50 | def forward(self, 51 | input: torch.Tensor, 52 | mask: Optional[torch.Tensor] = None 53 | ) -> torch.Tensor: 54 | x = input 55 | x = x + self.attention(x, mask) 56 | x = self.ln1(x) 57 | x = x + self.mlp(x) 58 | return self.ln2(x) 59 | 60 | 61 | @BLOCK.register(name="pre_ln") 62 | class PreLNBlock(BlockBase): 63 | """ BERT's Transformer Block 64 | """ 65 | 66 | def forward(self, 67 | input: torch.Tensor, 68 | mask: Optional[torch.Tensor] = None 69 | ) -> torch.Tensor: 70 | x = input 71 | x = self.ln1(x) 72 | x = x + self.attention(x, mask) 73 | x = self.ln2(x) 74 | return x + self.mlp(x) 75 | 76 | 77 | @BLOCK.register(name="ipre_ln") 78 | class ImprovedPreLNBlock(BlockBase): 79 | """ Megatron-LM's Transformer Block 80 | """ 81 | 82 | def forward(self, 83 | input: torch.Tensor, 84 | mask: Optional[torch.Tensor] = None 85 | ) -> torch.Tensor: 86 | x = input 87 | x = x + self.attention(self.ln1(x), mask) 88 | return x + self.mlp(self.ln2(x)) 89 | 90 | 91 | class TimmPreLNBlock(BlockBase): 92 | # Transformer Block used in timm 93 | def __init__(self, 94 | emb_dim: int, 95 | attention: SelfAttention, 96 | dropout_rate: float, 97 | droppath_rate: float, 98 | widen_factor: int, 99 | activation: str, 100 | norm: Type[nn.LayerNorm]): 101 | super().__init__(emb_dim, attention, dropout_rate, widen_factor, activation, norm) 102 | # double dropout 103 | self.mlp = nn.Sequential(nn.Linear(emb_dim, widen_factor * emb_dim), 104 | act_func(activation), 105 | nn.Dropout(dropout_rate), 106 | nn.Linear(widen_factor * emb_dim, 
emb_dim), 107 | nn.Dropout(dropout_rate)) 108 | self.droppath_rate = droppath_rate 109 | self.emb_dim = emb_dim 110 | 111 | def forward(self, 112 | input: torch.Tensor, 113 | mask=None 114 | ) -> torch.Tensor: 115 | x = input 116 | x = x + self.drop_path(self.attention(self.ln1(x))) 117 | x = x + self.drop_path(self.mlp(self.ln2(x))) 118 | return x 119 | 120 | def drop_path(self, 121 | input: torch.Tensor, 122 | ) -> torch.Tensor: 123 | if not self.training or self.droppath_rate == 0: 124 | return input 125 | 126 | keep_prob = 1 - self.droppath_rate 127 | # 1 with prob. of keep_prob 128 | drop = input.new_empty(input.size(0), 1, 1).bernoulli_(keep_prob) 129 | return input.div(keep_prob).mul(drop) 130 | 131 | 132 | class LayerScaleBlock(TimmPreLNBlock): 133 | # Transformer Block with Layer Scale 134 | def __init__(self, 135 | emb_dim: int, 136 | attention: SelfAttention, 137 | dropout_rate: float, 138 | droppath_rate: float, 139 | widen_factor: int, 140 | activation: str, 141 | norm: Type[nn.LayerNorm], 142 | init_scale: float): 143 | super().__init__(emb_dim, attention, dropout_rate, droppath_rate, widen_factor, activation, norm) 144 | self.gamma1 = nn.Parameter(init_scale * torch.ones(emb_dim)) 145 | self.gamma2 = nn.Parameter(init_scale * torch.ones(emb_dim)) 146 | 147 | def forward(self, 148 | input: torch.Tensor, 149 | cls_token: Optional[torch.Tensor] = None 150 | ) -> torch.Tensor: 151 | if cls_token is None: 152 | x = input 153 | else: 154 | x = cls_token 155 | input = torch.cat((cls_token, input), dim=1) 156 | x = x + self.drop_path(self.gamma1 * self.attention(self.ln1(input))) 157 | x = x + self.drop_path(self.gamma2 * self.mlp(self.ln2(x))) 158 | return x 159 | -------------------------------------------------------------------------------- /gpt.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Dict, Optional, Tuple 4 | 5 | import chika 6 | import homura 7 | import torch 8 | from homura import TensorTuple 9 | from homura.trainers import SupervisedTrainer 10 | from torch.nn import functional as F 11 | 12 | from models.gpt import GPT 13 | from nlp_utils import get_data 14 | 15 | 16 | class GPTTrainer(SupervisedTrainer): 17 | def __init__(self, *args, **kwargs): 18 | self.cfg = kwargs.pop('cfg') 19 | self.optim_cfg = kwargs.pop('optim_cfg') 20 | super().__init__(*args, **kwargs) 21 | 22 | def set_optimizer(self 23 | ) -> None: 24 | params_dict = self.model.param_groups 25 | optim_groups = [ 26 | {"params": params_dict['decay'], "weight_decay": self.optim_cfg.weight_decay}, 27 | {"params": params_dict['no_decay'], "weight_decay": 0} 28 | ] 29 | base = torch.optim._multi_tensor if self.optim_cfg.multi_tensor else torch.optim 30 | cls = getattr(base, "AdamW" if self.optim_cfg.name == "adamw" else "Adam") 31 | self.optimizer = cls(optim_groups, lr=self.optim_cfg.lr, betas=self.optim_cfg.betas) 32 | self.logger.debug(self.optimizer) 33 | 34 | def _loop(self, 35 | data_loader, 36 | mode: str 37 | ) -> None: 38 | 39 | self.inner_tqdm = self._tqdm(data_loader) 40 | for data in self.inner_tqdm: 41 | if self.is_train: 42 | # increment step here for `callbacks` 43 | self._step += 1 44 | self._iteration(data, mode) 45 | 46 | self.reporter.report(self.epoch, mode) 47 | self.logger.debug(f"epoch {self.epoch} finished") 48 | 49 | def data_preprocess(self, 50 | data: Dict[str, torch.Tensor] 51 | ) -> Tuple[torch.Tensor, int]: 52 | ids, mask = data['ids'], data['mask'] 53 | return TensorTuple((ids, 
mask)).to(self.device, non_blocking=self._cuda_nonblocking), ids.size(0)
54 | 
55 |     def iteration(self,
56 |                   data: torch.Tensor
57 |                   ) -> None:
58 |         ids, mask = data
59 |         input, target = ids[:, :-1], ids[:, 1:]
60 |         ignore_index = -100
61 |         target = target.masked_fill(mask[:, 1:] == 0, ignore_index)
62 |         with torch.cuda.amp.autocast(self._use_amp):
63 |             logits = self.model(input, mask[:, 1:])
64 |             loss = F.cross_entropy(logits.flatten(0, -2), target.reshape(-1), ignore_index=ignore_index)
65 |         self.reporter.add("loss", loss.detach())
66 |         if self.is_train:
67 |             self.optimizer.zero_grad(set_to_none=True)
68 |             if self._use_amp:
69 |                 self.scaler.scale(loss).backward()
70 |             else:
71 |                 loss.backward()
72 |             if self.cfg.grad_norm_clip > 0:
73 |                 if self._use_amp:
74 |                     self.scaler.unscale_(self.optimizer)
75 |                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_norm_clip)
76 |             if self._use_amp:
77 |                 self.scaler.step(self.optimizer)
78 |                 self.scaler.update()
79 |             else:
80 |                 self.optimizer.step()
81 |             self.scheduler.step()
82 |             if self.step % 500 == 0:
83 |                 self.inner_tqdm.set_postfix({"loss": f"{loss.cpu().item():.3e}"})
84 | 
85 |     @torch.no_grad()
86 |     def sample(self,
87 |                x: torch.Tensor,
88 |                num_steps: int,
89 |                temperature: float = 1.0,
90 |                sampling: bool = False,
91 |                only_tok_k: Optional[int] = None
92 |                ) -> torch.Tensor:
93 |         x = x.clone()
94 |         self.model.eval()
95 |         max_len = self.model.max_len
96 |         for k in range(num_steps):
97 |             x_cond = x if x.size(1) <= max_len else x[:, -max_len:]
98 |             logits = self.model(x_cond)
99 |             logits = logits[:, -1, :] / temperature
100 | 
101 |             if only_tok_k is not None:
102 |                 val, idx = logits.topk(k=only_tok_k)
103 |                 logits[logits < val[:, [-1]]] = float('-inf')
104 | 
105 |             probs = logits.softmax(dim=-1)
106 | 
107 |             if sampling:
108 |                 next_token = torch.multinomial(probs, num_samples=1)
109 |             else:
110 |                 next_token = probs.argmax(dim=-1, keepdim=True)  # keep shape Bx1 for the concatenation below
111 |             x = torch.cat([x, next_token], dim=1)
112 |         return x
113 | 
114 | 
115 | @chika.config
116 | class DataConfig:
117 |     name: str = chika.choices("wikitext", "gigaword")
118 |     batch_size: int = 64
119 |     max_len: int = 150
120 |     train_full: bool = False
121 | 
122 | 
123 | @chika.config
124 | class OptimConfig:
125 |     epochs: int = 20
126 |     name: str = chika.choices("adamw", "adam")
127 |     lr: float = 2e-4
128 |     weight_decay: float = 0.1
129 |     betas: Tuple[float] = chika.sequence(0.9, 0.98)
130 |     warmup_iters: int = 1_000
131 |     multi_tensor: bool = False
132 | 
133 | 
134 | @chika.config
135 | class ModelConfig:
136 |     block: str = chika.choices("ipre_ln", "pre_ln", "post_ln")
137 |     grad_norm_clip: float = 1.0
138 | 
139 |     num_heads: int = 8
140 |     emb_dim: int = 768
141 |     num_layers: int = 12
142 |     emb_dropout_rate: float = 0.1
143 |     attn_dropout_rate: float = 0.1
144 |     proj_dropout_rate: float = 0.1
145 | 
146 |     enable_checkpoint: bool = False
147 | 
148 | 
149 | @chika.config
150 | class Config:
151 |     model: ModelConfig
152 |     optim: OptimConfig
153 |     data: DataConfig
154 |     seed: int = 1
155 |     gpu: int = 0
156 |     amp: bool = False
157 | 
158 | 
159 | @chika.main(cfg_cls=Config, strict=True)
160 | def main(cfg: Config):
161 |     print(cfg)
162 |     torch.cuda.set_device(cfg.gpu)
163 |     homura.set_seed(cfg.seed)
164 |     train_loader, val_loader, tokenizer, vocab_size = get_data(**cfg.data.to_dict())
165 |     model = GPT.construct(**cfg.model.to_dict(), vocab_size=vocab_size, max_len=cfg.data.max_len)
166 |     # the optimizer is set up automatically in set_optimizer
167 |     scheduler = homura.lr_scheduler.CosineAnnealingWithWarmup(cfg.optim.epochs * len(train_loader), 1,
168 | 
cfg.optim.warmup_iters) 169 | sample_text = tokenizer.encode("however, as can be seen from") 170 | # sample_text = tokenizer.encode("in the beginning was the word") 171 | sample_tensor = torch.tensor(sample_text.ids[:sum(sample_text.attention_mask)]).view(1, -1) 172 | with GPTTrainer(model, None, None, 173 | reporters=[homura.reporters.TensorboardReporter(".")], 174 | scheduler=scheduler, 175 | cfg=cfg.model, 176 | optim_cfg=cfg.optim, 177 | use_amp=cfg.amp 178 | ) as trainer: 179 | for ep in trainer.epoch_range(cfg.optim.epochs): 180 | trainer.train(train_loader) 181 | trainer.test(val_loader, "val") 182 | sampled = trainer.sample(sample_tensor.to(trainer.device), num_steps=64, sampling=True, only_tok_k=100) 183 | sampled_text = tokenizer.decode(sampled.view(-1).cpu().tolist(), False) 184 | print(f"[{ep:>4}] train loss = {trainer.history['loss/train'][-1]:.3e}" 185 | f" val loss={trainer.history['loss/val'][-1]:.3e}|| {sampled_text}") 186 | trainer.save("outputs", f"{ep}") 187 | 188 | 189 | if __name__ == "__main__": 190 | import warnings 191 | 192 | # to avoid "Detected call of `lr_scheduler.step()` before `optimizer.step()`... when using AMP 193 | warnings.filterwarnings("ignore", message="Detected call of") 194 | main() 195 | -------------------------------------------------------------------------------- /models/vit.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from copy import deepcopy 5 | from functools import partial 6 | from typing import Type 7 | 8 | import torch 9 | from homura import Registry 10 | from homura.modules import EMA 11 | from torch import nn 12 | 13 | from .attentions import ClassAttention, SelfAttention 14 | from .base import TransformerBase 15 | from .blocks import LayerScaleBlock, TimmPreLNBlock 16 | from .embeddings import PatchEmbed2d 17 | 18 | ViTs = Registry("vit", nn.Module) 19 | 20 | 21 | class ViT(TransformerBase): 22 | """ Vision Transformer 23 | """ 24 | 25 | def __init__(self, 26 | attention: SelfAttention, 27 | num_classes: int, 28 | image_size: int or tuple, 29 | patch_size: int or tuple, 30 | emb_dim: int, 31 | num_layers: int, 32 | emb_dropout_rate: float, 33 | dropout_rate: float, 34 | droppath_rate: float, 35 | in_channels: int, 36 | norm: Type[nn.LayerNorm] = nn.LayerNorm, 37 | mlp_widen_factor: int = 4, 38 | activation: str = "gelu", 39 | enable_checkpointing=False 40 | ): 41 | blocks = [TimmPreLNBlock(emb_dim, deepcopy(attention), dropout_rate=dropout_rate, droppath_rate=r, 42 | widen_factor=mlp_widen_factor, norm=norm, activation=activation) 43 | for r in [x.item() for x in torch.linspace(0, droppath_rate, num_layers)] 44 | ] 45 | super().__init__(nn.Sequential(*blocks), enable_checkpointing) 46 | image_size = (image_size, image_size) if isinstance(image_size, int) else image_size 47 | patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size 48 | num_patches = math.prod(image_size) // math.prod(patch_size) 49 | 50 | self.image_size = image_size 51 | self.patch_size = patch_size 52 | 53 | self.patch_emb = PatchEmbed2d(patch_size, emb_dim, in_channels) 54 | self.pos_emb = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim)) 55 | self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim)) 56 | self.dropout = nn.Dropout(emb_dropout_rate) 57 | self.norm = norm(emb_dim) 58 | self.fc = nn.Linear(emb_dim, num_classes) 59 | self.init_weights() 60 | 61 | def forward(self, 62 | input: torch.Tensor 63 | ) -> torch.Tensor: 64 | 
        x = self.patch_emb(input)  # BxNxC
65 |         cls_token = self.cls_token.expand(x.size(0), -1, -1)
66 |         x = torch.cat((cls_token, x), dim=1)  # Bx(N+1)xC
67 |         x = self.dropout(self.pos_emb + x)
68 |         x = self.norm(self.blocks(x))
69 |         return self.fc(x[:, 0])
70 | 
71 |     def init_weights(self):
72 |         for module in self.modules():
73 |             if isinstance(module, nn.Linear):
74 |                 nn.init.trunc_normal_(module.weight, std=0.02)
75 |                 if module.bias is not None:
76 |                     nn.init.normal_(module.bias, std=1e-6)
77 |             if isinstance(module, nn.LayerNorm):
78 |                 nn.init.ones_(module.weight)
79 |                 nn.init.zeros_(module.bias)
80 | 
81 |         nn.init.zeros_(self.fc.weight)
82 |         nn.init.zeros_(self.fc.bias)
83 |         proj_w = self.patch_emb.proj
84 |         fan_in = proj_w.in_channels * math.prod(proj_w.kernel_size)
85 |         nn.init.trunc_normal_(proj_w.weight, std=math.sqrt(1 / fan_in))
86 |         nn.init.zeros_(self.patch_emb.proj.bias)
87 |         nn.init.trunc_normal_(self.pos_emb, std=0.02)
88 |         nn.init.trunc_normal_(self.cls_token, std=0.02)
89 | 
90 |     @classmethod
91 |     def construct(cls,
92 |                   emb_dim: int,
93 |                   num_layers: int,
94 |                   num_heads: int,
95 |                   patch_size: int,
96 |                   dropout_rate: float = 0,
97 |                   attn_dropout_rate: float = 0,
98 |                   droppath_rate: float = 0,
99 |                   num_classes: int = 1_000,
100 |                   image_size: int = 224,
101 |                   in_channels: int = 3,
102 |                   layernorm_eps: float = 1e-6,
103 |                   activation: str = "gelu",
104 |                   **kwargs
105 |                   ) -> ViT:
106 |         attention = SelfAttention(emb_dim, num_heads, attn_dropout_rate, dropout_rate)
107 |         return cls(attention, num_classes, image_size, patch_size, emb_dim, num_layers,
108 |                    dropout_rate, dropout_rate, droppath_rate, in_channels=in_channels,
109 |                    norm=partial(nn.LayerNorm, eps=layernorm_eps), activation=activation)
110 | 
111 | 
112 | class ViTEMA(EMA):
113 |     def __init__(self, *args, **kwargs):
114 |         super().__init__(*args, **kwargs)
115 |         self.ema_model.eval()
116 | 
117 |     @property
118 |     def param_groups(self):
119 |         return self.original_model.param_groups
120 | 
121 | 
122 | @ViTs.register
123 | def vit_t16(**kwargs) -> ViT:
124 |     return ViT.construct(192, 12, 3, 16, **kwargs)
125 | 
126 | 
127 | @ViTs.register
128 | def vit_t16_384(**kwargs) -> ViT:
129 |     return ViT.construct(192, 12, 3, 16, image_size=384, **kwargs)
130 | 
131 | 
132 | @ViTs.register
133 | def vit_b16(**kwargs) -> ViT:
134 |     return ViT.construct(768, 12, 12, 16, **kwargs)
135 | 
136 | 
137 | @ViTs.register
138 | def vit_b16_384(**kwargs) -> ViT:
139 |     return ViT.construct(768, 12, 12, 16, image_size=384, **kwargs)
140 | 
141 | 
142 | @ViTs.register
143 | def vit_b32(**kwargs) -> ViT:
144 |     return ViT.construct(768, 12, 12, 32, **kwargs)
145 | 
146 | 
147 | @ViTs.register
148 | def vit_l16(**kwargs) -> ViT:
149 |     return ViT.construct(1024, 24, 16, 16, **kwargs)
150 | 
151 | 
152 | @ViTs.register
153 | def vit_l32(**kwargs) -> ViT:
154 |     return ViT.construct(1024, 24, 16, 32, **kwargs)
155 | 
156 | 
157 | class CaiTSequential(nn.Sequential):
158 |     # a helper module for CaiT to use gradient checkpointing
159 |     def forward(self,
160 |                 input: torch.Tensor,
161 |                 cls_token: torch.Tensor
162 |                 ) -> tuple[torch.Tensor, torch.Tensor]:
163 |         for module in self:
164 |             if isinstance(module.attention, ClassAttention):
165 |                 cls_token = module(input, cls_token)
166 |             else:
167 |                 input = module(input)
168 |         return input, cls_token
169 | 
170 | 
171 | class CaiT(TransformerBase):
172 |     """ CaiT from Touvron+2021 Going deeper with Image Transformers.
    https://github.com/facebookresearch/deit
173 |     """
174 | 
175 |     def __init__(self,
176 |                  attention: SelfAttention,
177 |                  cls_attention: ClassAttention,
178 |                  num_classes: int,
179 |                  image_size: int or tuple,
180 |                  patch_size: int or tuple,
181 |                  emb_dim: int,
182 |                  num_layers: int,
183 |                  num_cls_layers: int,
184 |                  emb_dropout_rate: float,
185 |                  dropout_rate: float,
186 |                  droppath_rate: float,
187 |                  in_channels: int,
188 |                  norm: Type[nn.LayerNorm] = nn.LayerNorm,
189 |                  mlp_widen_factor: int = 4,
190 |                  activation: str = "gelu",
191 |                  enable_checkpointing=False,
192 |                  init_scale: float = 1e-5
193 |                  ):
194 |         blocks1 = [LayerScaleBlock(emb_dim, deepcopy(attention), dropout_rate=dropout_rate, droppath_rate=droppath_rate,
195 |                                    widen_factor=mlp_widen_factor, norm=norm, activation=activation,
196 |                                    init_scale=init_scale)
197 |                    for _ in range(num_layers)]
198 |         blocks2 = [LayerScaleBlock(emb_dim, deepcopy(cls_attention), dropout_rate=0, droppath_rate=0,
199 |                                    widen_factor=mlp_widen_factor, norm=norm, activation=activation,
200 |                                    init_scale=init_scale)
201 |                    for _ in range(num_cls_layers)]
202 |         blocks = blocks1 + blocks2
203 |         super().__init__(CaiTSequential(*blocks), enable_checkpointing)
204 |         image_size = (image_size, image_size) if isinstance(image_size, int) else image_size
205 |         patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
206 |         num_patches = math.prod(image_size) // math.prod(patch_size)
207 | 
208 |         self.image_size = image_size
209 |         self.patch_size = patch_size
210 | 
211 |         self.patch_emb = PatchEmbed2d(patch_size, emb_dim, in_channels)
212 |         self.pos_emb = nn.Parameter(torch.zeros(1, num_patches, emb_dim))
213 |         self.cls_token = nn.Parameter(torch.zeros(1, 1, emb_dim))
214 |         self.dropout = nn.Dropout(emb_dropout_rate)
215 |         self.norm = norm(emb_dim)
216 |         self.fc = nn.Linear(emb_dim, num_classes)
217 |         self.init_weights()
218 | 
219 |     def forward(self,
220 |                 input: torch.Tensor
221 |                 ) -> torch.Tensor:
222 |         x = self.patch_emb(input)  # BxNxC
223 |         cls_token = self.cls_token.expand(x.size(0), -1, -1)
224 |         x = self.dropout(self.pos_emb + x)
225 |         x = torch.cat(self.blocks(x, cls_token)[::-1], dim=1)  # (cls_token, patches) so that x[:, 0] is the class token
226 |         x = self.norm(x)
227 |         return self.fc(x[:, 0])
228 | 
229 |     def init_weights(self):
230 |         for module in self.modules():
231 |             if isinstance(module, nn.Linear):
232 |                 nn.init.trunc_normal_(module.weight, std=0.02)
233 |                 if module.bias is not None:
234 |                     nn.init.zeros_(module.bias)
235 |             if isinstance(module, nn.LayerNorm):
236 |                 nn.init.ones_(module.weight)
237 |                 nn.init.zeros_(module.bias)
238 | 
239 |         for name, param in self.named_parameters():
240 |             if "talk" in name:
241 |                 nn.init.trunc_normal_(param, std=0.02)
242 | 
243 |         nn.init.zeros_(self.fc.weight)
244 |         nn.init.zeros_(self.fc.bias)
245 |         nn.init.trunc_normal_(self.patch_emb.proj.weight)
246 |         nn.init.zeros_(self.patch_emb.proj.bias)
247 |         nn.init.trunc_normal_(self.pos_emb, std=0.02)
248 |         nn.init.trunc_normal_(self.cls_token, std=0.02)
249 | 
250 |     @classmethod
251 |     def construct(cls,
252 |                   emb_dim: int,
253 |                   num_layers: int,
254 |                   num_cls_layers: int,
255 |                   num_heads: int,
256 |                   patch_size: int,
257 |                   dropout_rate: float = 0,
258 |                   attn_dropout_rate: float = 0,
259 |                   droppath_rate: float = 0,
260 |                   num_classes: int = 1_000,
261 |                   image_size: int = 224,
262 |                   in_channels: int = 3,
263 |                   layernorm_eps: float = 1e-6,
264 |                   activation: str = "gelu",
265 |                   **kwargs
266 |                   ) -> CaiT:
267 |         attention = SelfAttention(emb_dim, num_heads, attn_dropout_rate, dropout_rate, talking_heads=True)
268 |         cls_attention
= ClassAttention(emb_dim, num_heads, attn_dropout_rate, dropout_rate) 269 | return cls(attention, cls_attention, num_classes, image_size, patch_size, emb_dim, num_layers, num_cls_layers, 270 | dropout_rate, dropout_rate, droppath_rate, in_channels=in_channels, 271 | norm=partial(nn.LayerNorm, eps=layernorm_eps), activation=activation, **kwargs) 272 | 273 | 274 | @ViTs.register 275 | def cait_xs24(**kwargs): 276 | return CaiT.construct(emb_dim=288, num_layers=24, num_cls_layers=2, num_heads=6, patch_size=16, **kwargs) 277 | 278 | 279 | @ViTs.register 280 | def cait_s24_224(**kwargs): 281 | return CaiT.construct(emb_dim=384, num_layers=24, num_cls_layers=2, num_heads=8, patch_size=16, **kwargs) 282 | 283 | 284 | @ViTs.register 285 | def cait_s24_384(**kwargs): 286 | return CaiT.construct(emb_dim=384, num_layers=24, num_cls_layers=2, num_heads=8, patch_size=16, image_size=384, 287 | **kwargs) 288 | 289 | 290 | @ViTs.register 291 | def cait_m36_384(**kwargs): 292 | return CaiT.construct(emb_dim=384, num_layers=36, num_cls_layers=2, num_heads=8, patch_size=16, image_size=384, 293 | init_scale=1e-6, **kwargs) 294 | 295 | 296 | ViTs.register(cait_s24_384, name="cait_s24") 297 | ViTs.register(cait_m36_384, name="cait_m36") 298 | --------------------------------------------------------------------------------
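As a quick smoke test of the model code above, here is a minimal usage sketch: it builds a small GPT via `GPT.construct`, runs a forward pass on random token ids, and feeds `param_groups` from `TransformerBase` into AdamW, mirroring how the training scripts split decayed and non-decayed parameters. It assumes the requirements from the README are installed and that it is run from the repository root; the hyperparameters below are arbitrary and chosen only for illustration.

```python
import torch

from models.gpt import GPT

# a tiny GPT; `block` is one of the registered Transformer blocks ("ipre_ln", "pre_ln", "post_ln")
model = GPT.construct(block="ipre_ln",
                      vocab_size=1_000,
                      max_len=64,
                      num_heads=4,
                      emb_dim=128,
                      num_layers=2)

ids = torch.randint(0, 1_000, (2, 32))  # BxN token ids with N <= max_len
logits = model(ids)                     # BxNxV; the causal mask is built inside GPT.forward
print(logits.shape)                     # torch.Size([2, 32, 1000])

# weight-decay branching, as used by the trainers in gpt.py and vit.py
groups = model.param_groups             # {"decay": [...], "no_decay": [...]}
optimizer = torch.optim.AdamW([{"params": groups["decay"], "weight_decay": 0.1},
                               {"params": groups["no_decay"], "weight_decay": 0.0}],
                              lr=2e-4)
```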