├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── alt_activations ├── __init__.py ├── cifar_batch.py ├── gpt.py ├── models │ ├── __init__.py │ └── resnet.py └── vit.py ├── attention_supervision ├── cifar_batch.py └── train_gpt.py ├── bert_glue.py ├── bert_mlm.py ├── boneless_attn ├── gpt.py └── vit.py ├── calibrated_attention ├── attentions.py └── gpt.py ├── common ├── __init__.py ├── activations.py ├── cifar_utils.py ├── gpt.py ├── randomaug.py └── vit.py ├── dipole_attn ├── __init__.py ├── bert_glue.py ├── bert_mlm.py ├── gpt.py ├── t5_patch.py ├── t5_translation.py └── train_cifar10.py ├── hf_gpt_blocks └── hf_gpt.py ├── mlp_mods ├── __init__.py ├── gpt.py └── vit.py ├── multi_token_pred ├── __init__.py └── gpt.py ├── qknorm_half ├── __init__.py ├── gpt.py ├── utils.py └── vit.py ├── relative_optimizers ├── caution_muon_adam.py ├── gpt.py ├── muon.py ├── relative_adam.py ├── relative_muon.py ├── relative_muon_2.py └── vit.py ├── residual_stream_scale ├── __init__.py ├── something.ipynb └── something.py ├── sam_optimizers ├── adam_two_momentum_perturb.py ├── adam_wd_perturb.py ├── gpt.py ├── muon.py ├── muon_adam_perturb.py ├── nesterov_perturb.py └── utils.py ├── sparsemax_attn ├── __init__.py ├── bert │ ├── patch.py │ ├── run_glue_no_trainer.py │ ├── run_mlm_no_trainer.py │ ├── run_ner_no_trainer.py │ ├── run_qa_no_trainer.py │ └── run_swag_no_trainer.py └── gpt.py ├── structured_transformer ├── __init__.py ├── common.py ├── gpt.py ├── train_cifar10.py ├── train_gpt.py └── vit.py ├── t5_translation.py ├── train_cifar10.py └── train_gpt.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.cpython* 2 | *_output* 3 | *wandb* 4 | 5 | *.log 6 | *.json 7 | *wandb* 8 | *.safetensors* 9 | *ckpt* 10 | data 11 | log 12 | model-output 13 | __pycache__ 14 | checkpoint 15 | None -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransformerExperiments 2 | 3 | a buncha different experiments with transformers and attention 4 | 5 | 6 | ## Dipole Attention 7 | Explainer post: https://www.ethansmith2000.com/post/dipole-attention-opposites-may-be-deep-connections 8 | 9 | 10 | 11 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/__init__.py -------------------------------------------------------------------------------- /alt_activations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/alt_activations/__init__.py -------------------------------------------------------------------------------- /alt_activations/cifar_batch.py: -------------------------------------------------------------------------------- 1 | 2 | from train_cifar10 import train_model, default_args 3 | from copy import deepcopy 4 | 5 | runs = [ 6 | # dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"]), 7 | # dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], act_powers=[2, 2, 2, 2, 2, 2]), 8 | # dict(acts=["leaky", "leaky", "leaky", "leaky", "leaky", "leaky"], act_powers=[3, 3, 3, 3, 3, 3]), 9 | # dict(acts=["relu", "relu", "relu", "relu", "relu", 
"relu"]), 10 | # dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[2, 2, 2, 2, 2, 2]), 11 | # dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[3, 3, 3, 3, 3, 3]), 12 | # dict(acts=["relu_sin", "relu_sin", "relu_sin", "relu_sin", "relu_sin", "relu_sin"]), 13 | 14 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], val_act="gelu"), 15 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="gelu"), 16 | dict(acts=["leaky", "leaky", "leaky", "leaky", "leaky", "leaky"], act_powers=[3, 3, 3, 3, 3, 3], val_act="gelu"), 17 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], val_act="gelu"), 18 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="gelu"), 19 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[3, 3, 3, 3, 3, 3], val_act="gelu"), 20 | 21 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], val_act="gelu", attn_power=2), 22 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="gelu", attn_power=2), 23 | dict(acts=["leaky", "leaky", "leaky", "leaky", "leaky", "leaky"], act_powers=[3, 3, 3, 3, 3, 3], val_act="gelu", attn_power=2), 24 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], val_act="gelu", attn_power=2), 25 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="gelu", attn_power=2), 26 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[3, 3, 3, 3, 3, 3], val_act="gelu", attn_power=2), 27 | 28 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], val_act="leaky", attn_power=3), 29 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="leaky", attn_power=3), 30 | dict(acts=["leaky", "leaky", "leaky", "leaky", "leaky", "leaky"], act_powers=[3, 3, 3, 3, 3, 3], val_act="leaky", attn_power=3), 31 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], val_act="leaky", attn_power=3), 32 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="leaky", attn_power=3), 33 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[3, 3, 3, 3, 3, 3], val_act="leaky", attn_power=3), 34 | ] 35 | 36 | for run in runs: 37 | args = deepcopy(default_args) 38 | args.update(run) 39 | train_model(args) 40 | 41 | -------------------------------------------------------------------------------- /alt_activations/gpt.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | class Conv1D(nn.Module): 5 | """ 6 | 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). 7 | 8 | Basically works like a linear layer but the weights are transposed. 9 | 10 | Args: 11 | nf (`int`): The number of output features. 12 | nx (`int`): The number of input features. 
13 | """ 14 | 15 | def __init__(self, nf, nx): 16 | super().__init__() 17 | self.nf = nf 18 | self.weight = nn.Parameter(torch.empty(nx, nf)) 19 | self.bias = nn.Parameter(torch.zeros(nf)) 20 | nn.init.normal_(self.weight, std=0.02) 21 | 22 | def forward(self, x): 23 | size_out = x.size()[:-1] + (self.nf,) 24 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 25 | x = x.view(size_out) 26 | return x 27 | 28 | 29 | class GPT2MLP(nn.Module): 30 | def __init__(self, config, activation_type='gelu', power=1.0): 31 | super().__init__() 32 | embed_dim = config.hidden_size 33 | intermediate_size = embed_dim * 4 34 | self.c_fc = Conv1D(intermediate_size, embed_dim) 35 | self.c_proj = Conv1D(embed_dim, intermediate_size) 36 | self.act = Activation(activation_type=activation_type, power=power) 37 | self.dropout = nn.Dropout(config.resid_pdrop) 38 | 39 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 40 | hidden_states = self.c_fc(hidden_states) 41 | hidden_states = self.act(hidden_states) 42 | hidden_states = self.c_proj(hidden_states) 43 | hidden_states = self.dropout(hidden_states) 44 | return hidden_states 45 | 46 | 47 | class NewGPT2Attention(GPT2Attention): 48 | def __init__(self, config, is_cross_attention=False, layer_idx=None, value_act=None, post_attn_act=None, power=1.0): 49 | super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx) 50 | # self.value_act = nn.Identity() if value_act is None else Activation(value_act) 51 | # self.post_attn_act = nn.Identity() if post_attn_act is None else Activation(post_attn_act) 52 | dim = self.c_attn.nf // 3 53 | self.value_act = nn.Identity() if value_act is None else LinearAct(dim, dim, activation_type=value_act, power=power, pre_act=True) 54 | self.post_attn_act = nn.Identity() if post_attn_act is None else LinearAct(dim, dim, activation_type=post_attn_act, power=power, pre_act=True) 55 | 56 | def forward( 57 | self, 58 | hidden_states: Optional[Tuple[torch.FloatTensor]], 59 | layer_past: Optional[Tuple[torch.Tensor]] = None, 60 | attention_mask: Optional[torch.FloatTensor] = None, 61 | head_mask: Optional[torch.FloatTensor] = None, 62 | encoder_hidden_states: Optional[torch.Tensor] = None, 63 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 64 | use_cache: Optional[bool] = False, 65 | output_attentions: Optional[bool] = False, 66 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: 67 | if encoder_hidden_states is not None: 68 | if not hasattr(self, "q_attn"): 69 | raise ValueError( 70 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 71 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
72 | ) 73 | 74 | query = self.q_attn(hidden_states) 75 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 76 | attention_mask = encoder_attention_mask 77 | else: 78 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 79 | 80 | value = self.value_act(value) 81 | 82 | query = self._split_heads(query, self.num_heads, self.head_dim) 83 | key = self._split_heads(key, self.num_heads, self.head_dim) 84 | value = self._split_heads(value, self.num_heads, self.head_dim) 85 | 86 | if layer_past is not None: 87 | past_key, past_value = layer_past 88 | key = torch.cat((past_key, key), dim=-2) 89 | value = torch.cat((past_value, value), dim=-2) 90 | 91 | if use_cache is True: 92 | present = (key, value) 93 | else: 94 | present = None 95 | 96 | if self.reorder_and_upcast_attn: 97 | attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) 98 | else: 99 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 100 | 101 | attn_output = self.post_attn_act(attn_output) 102 | 103 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 104 | attn_output = self.c_proj(attn_output) 105 | attn_output = self.resid_dropout(attn_output) 106 | 107 | outputs = (attn_output, present) 108 | if output_attentions: 109 | outputs += (attn_weights,) 110 | 111 | return outputs # a, present, (attentions) 112 | 113 | 114 | def patch_attn(model, value_act=None, post_attn_act=None, power=1.0): 115 | conf = model.config 116 | idx = 0 117 | for n,m in model.named_modules(): 118 | if hasattr(m, "attn"): 119 | del m.attn 120 | m.add_module("attn", NewGPT2Attention(conf, is_cross_attention=False, layer_idx=None, value_act=value_act, post_attn_act=post_attn_act, power=power)) 121 | idx += 1 122 | print('current idx', idx) 123 | 124 | 125 | def patch_mlp(model, activation_type, activation_powers): 126 | idx = 0 127 | for n,m in model.named_modules(): 128 | if hasattr(m, "mlp"): 129 | del m.mlp 130 | m.add_module("mlp", GPT2MLP(model.config, activation_type=activation_type[idx], power=activation_powers[idx])) 131 | idx += 1 132 | 133 | 134 | 135 | def run_name(): 136 | base_str = "base" 137 | if args["value_act"] is not None: 138 | base_str = f"{base_str}_vact_{args['value_act']}" 139 | if args["post_attn_act"] is not None: 140 | base_str = f"{base_str}_pact_{args['post_attn_act']}" 141 | args["output_dir"] = f"{args['base_output_dir']}/{base_str}" 142 | 143 | unique_activations = list(set(args['activations'])) 144 | non_gelu = [a for a in unique_activations if a != "gelu"] 145 | if len(non_gelu) > 0: 146 | non_gelu = non_gelu[0] 147 | indices = tuple([i+1 for i, a in enumerate(args['activations']) if a == non_gelu]) 148 | base_str = base_str + "_{}-{}".format(non_gelu, indices) 149 | 150 | unique_powers = list(set(args['activation_powers'])) 151 | non_one = [p for p in unique_powers if p != 1] 152 | if len(non_one) > 0: 153 | non_one = non_one[0] 154 | indices = tuple([i+1 for i, p in enumerate(args['activation_powers']) if p == non_one]) 155 | base_str = base_str + "_{}-{}".format(non_one, indices) 156 | 157 | 158 | extra_args = { 159 | 160 | "activations": ["gelu", "gelu", "gelu", "gelu", "gelu", "gelu", "gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], 161 | "activation_powers": [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 162 | "value_act": "leaky", 163 | "post_attn_act": None, 164 | "attn_power": 3.0, 165 | } 
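# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).
# It shows one plausible way the patch helpers above could be combined with the
# `extra_args` config on a Hugging Face GPT-2 model. This assumes the module's
# usual imports (torch, nn, transformers' GPT2Attention, plus Activation and
# LinearAct from common.activations) are in place and that `transformers` is
# installed; the repo's actual entry point (train_gpt.py) may wire this up
# differently.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # Replace every block's MLP with one using the requested activation and power.
    patch_mlp(model, extra_args["activations"], extra_args["activation_powers"])

    # Replace every block's attention with one applying value / post-attention activations.
    patch_attn(
        model,
        value_act=extra_args["value_act"],
        post_attn_act=extra_args["post_attn_act"],
        power=extra_args["attn_power"],
    )

    # The first block now carries the patched modules.
    print(type(model.transformer.h[0].mlp).__name__)   # GPT2MLP (the variant defined above)
    print(type(model.transformer.h[0].attn).__name__)  # NewGPT2Attention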
-------------------------------------------------------------------------------- /alt_activations/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/alt_activations/models/__init__.py -------------------------------------------------------------------------------- /alt_activations/models/resnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | '''ResNet in PyTorch. 4 | For Pre-activation ResNet, see 'preact_resnet.py'. 5 | Reference: 6 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 7 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 8 | ''' 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BasicBlock(nn.Module): 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1): 18 | super(BasicBlock, self).__init__() 19 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*planes: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*planes) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = self.bn2(self.conv2(out)) 34 | out += self.shortcut(x) 35 | out = F.relu(out) 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | 42 | def __init__(self, in_planes, planes, stride=1): 43 | super(Bottleneck, self).__init__() 44 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(planes) 46 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 50 | 51 | self.shortcut = nn.Sequential() 52 | if stride != 1 or in_planes != self.expansion*planes: 53 | self.shortcut = nn.Sequential( 54 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(self.expansion*planes) 56 | ) 57 | 58 | def forward(self, x): 59 | out = F.relu(self.bn1(self.conv1(x))) 60 | out = F.relu(self.bn2(self.conv2(out))) 61 | out = self.bn3(self.conv3(out)) 62 | out += self.shortcut(x) 63 | out = F.relu(out) 64 | return out 65 | 66 | 67 | class ResNet(nn.Module): 68 | def __init__(self, block, num_blocks, num_classes=10): 69 | super(ResNet, self).__init__() 70 | self.in_planes = 64 71 | 72 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 73 | self.bn1 = nn.BatchNorm2d(64) 74 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 75 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 76 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 77 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 78 | self.linear = nn.Linear(512*block.expansion, num_classes) 79 | 80 | def _make_layer(self, block, planes, num_blocks, stride): 81 | strides = 
[stride] + [1]*(num_blocks-1) 82 | layers = [] 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, stride)) 85 | self.in_planes = planes * block.expansion 86 | return nn.Sequential(*layers) 87 | 88 | 89 | def get_feat(self, x): 90 | out = F.relu(self.bn1(self.conv1(x))) 91 | out = self.layer1(out) 92 | out = self.layer2(out) 93 | out = self.layer3(out) 94 | out = self.layer4(out) 95 | out = F.avg_pool2d(out, 4) 96 | out = out.view(out.size(0), -1) 97 | return out 98 | 99 | def forward(self, img): 100 | x = self.get_feat(img) 101 | return self.linear(x) 102 | 103 | 104 | def ResNet18(): 105 | return ResNet(BasicBlock, [2,2,2,2]) 106 | 107 | def ResNet34(): 108 | return ResNet(BasicBlock, [3,4,6,3]) 109 | 110 | def ResNet50(): 111 | return ResNet(Bottleneck, [3,4,6,3]) 112 | 113 | def ResNet101(): 114 | return ResNet(Bottleneck, [3,4,23,3]) 115 | 116 | def ResNet152(): 117 | return ResNet(Bottleneck, [3,8,36,3]) 118 | 119 | 120 | # def test(): 121 | # net = ResNet18() 122 | # y = net(torch.randn(1,3,32,32)) 123 | # print(y.size()) 124 | 125 | # test() -------------------------------------------------------------------------------- /alt_activations/vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | from common.activations import Activation, LinearAct 10 | 11 | class AltActFeedForward(nn.Module): 12 | def __init__(self, dim, hidden_dim, dropout = 0., act="gelu", act_power=1): 13 | super().__init__() 14 | self.net = nn.Sequential( 15 | nn.Linear(dim, hidden_dim), 16 | Activation(act, power=act_power, dim=hidden_dim), 17 | nn.Dropout(dropout), 18 | nn.Linear(hidden_dim, dim), 19 | nn.Dropout(dropout) 20 | ) 21 | def forward(self, x): 22 | return self.net(x) 23 | 24 | class AltActAttention(nn.Module): 25 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., val_act=None, post_attn_act=None, power=1.0): 26 | super().__init__() 27 | inner_dim = dim_head * heads 28 | project_out = not (heads == 1 and dim_head == dim) 29 | self.heads = heads 30 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 31 | self.to_out = nn.Sequential( 32 | nn.Linear(inner_dim, dim), 33 | nn.Dropout(dropout) 34 | ) if project_out else nn.Identity() 35 | # self.val_act = nn.Identity() if val_act is None else Activation(val_act) 36 | # self.post_attn_act = nn.Identity() if post_attn_act is None else Activation(post_attn_act) 37 | self.val_act = LinearAct(inner_dim, inner_dim, activation_type=val_act, power=power, pre_act=True) if val_act is not None else nn.Identity() 38 | self.post_attn_act = LinearAct(dim, dim, activation_type=post_attn_act, power=power, pre_act=True) if post_attn_act is not None else nn.Identity() 39 | 40 | 41 | def forward(self, x): 42 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 43 | v = self.val_act(v) 44 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q,k,v)) 45 | out = torch.nn.functional.scaled_dot_product_attention(q, k, v) 46 | out = rearrange(out, 'b h n d -> b n (h d)') 47 | out = self.post_attn_act(out) 48 | return self.to_out(out) 49 | 50 | 51 | def patch_model(model, args, exp_args): 52 | for name, module in model.named_modules(): 53 | if hasattr(module, 'attn'): 54 | module.attn = AltActAttention(args.dim, args.heads, args.dim_head, args.dropout, args.val_act, 
args.post_attn_act, args.act_power) 55 | if hasattr(module, 'ff'): 56 | module.ff = AltActFeedForward(args.dim, args.mlp_dim, args.dropout, args.act, args.act_power) 57 | 58 | # init weights again 59 | model.init_weights() 60 | 61 | return model 62 | -------------------------------------------------------------------------------- /attention_supervision/cifar_batch.py: -------------------------------------------------------------------------------- 1 | 2 | from train_cifar10 import train_model, default_args 3 | from copy import deepcopy 4 | 5 | runs = [ 6 | # dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"]), 7 | # dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], act_powers=[2, 2, 2, 2, 2, 2]), 8 | # dict(acts=["leaky", "leaky", "leaky", "leaky", "leaky", "leaky"], act_powers=[3, 3, 3, 3, 3, 3]), 9 | # dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"]), 10 | # dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[2, 2, 2, 2, 2, 2]), 11 | # dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[3, 3, 3, 3, 3, 3]), 12 | # dict(acts=["relu_sin", "relu_sin", "relu_sin", "relu_sin", "relu_sin", "relu_sin"]), 13 | 14 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], val_act="gelu"), 15 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="gelu"), 16 | dict(acts=["leaky", "leaky", "leaky", "leaky", "leaky", "leaky"], act_powers=[3, 3, 3, 3, 3, 3], val_act="gelu"), 17 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], val_act="gelu"), 18 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="gelu"), 19 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[3, 3, 3, 3, 3, 3], val_act="gelu"), 20 | 21 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], val_act="gelu", attn_power=2), 22 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="gelu", attn_power=2), 23 | dict(acts=["leaky", "leaky", "leaky", "leaky", "leaky", "leaky"], act_powers=[3, 3, 3, 3, 3, 3], val_act="gelu", attn_power=2), 24 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], val_act="gelu", attn_power=2), 25 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="gelu", attn_power=2), 26 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[3, 3, 3, 3, 3, 3], val_act="gelu", attn_power=2), 27 | 28 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], val_act="leaky", attn_power=3), 29 | dict(acts=["gelu", "gelu", "gelu", "gelu", "gelu", "gelu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="leaky", attn_power=3), 30 | dict(acts=["leaky", "leaky", "leaky", "leaky", "leaky", "leaky"], act_powers=[3, 3, 3, 3, 3, 3], val_act="leaky", attn_power=3), 31 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], val_act="leaky", attn_power=3), 32 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[2, 2, 2, 2, 2, 2], val_act="leaky", attn_power=3), 33 | dict(acts=["relu", "relu", "relu", "relu", "relu", "relu"], act_powers=[3, 3, 3, 3, 3, 3], val_act="leaky", attn_power=3), 34 | ] 35 | 36 | for run in runs: 37 | args = deepcopy(default_args) 38 | args.update(run) 39 | train_model(args) 40 | 41 | -------------------------------------------------------------------------------- /boneless_attn/gpt.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | from torch import nn 3 | from typing import Optional, Tuple 4 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, eager_attention_forward, ALL_ATTENTION_FUNCTIONS 5 | 6 | 7 | class BaselineAttention(GPT2Attention): 8 | """ 9 | Same attention class, but patched to let wq, wk, wv to be seperate matrices 10 | """ 11 | 12 | def __init__(self, config, is_cross_attention=False, layer_idx=None, **kwargs): 13 | super().__init__(config, is_cross_attention, layer_idx) 14 | del self.c_attn 15 | del self.c_proj 16 | self.q_attn = nn.Linear(config.n_embd, config.n_embd, bias=False) 17 | self.k_attn = nn.Linear(config.n_embd, config.n_embd, bias=False) 18 | self.v_attn = nn.Linear(config.n_embd, config.n_embd, bias=False) 19 | self.o_attn = nn.Linear(config.n_embd, config.n_embd, bias=True) 20 | 21 | def forward( 22 | self, 23 | hidden_states: Optional[Tuple[torch.FloatTensor]], 24 | layer_past: Optional[Tuple[torch.Tensor]] = None, 25 | attention_mask: Optional[torch.FloatTensor] = None, 26 | head_mask: Optional[torch.FloatTensor] = None, 27 | encoder_hidden_states: Optional[torch.Tensor] = None, 28 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 29 | use_cache: Optional[bool] = False, 30 | output_attentions: Optional[bool] = False, 31 | **kwargs, 32 | ): 33 | # query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2) 34 | query_states = self.q_attn(hidden_states) 35 | key_states = self.k_attn(hidden_states) 36 | value_states = self.v_attn(hidden_states) 37 | 38 | shape_q = (*query_states.shape[:-1], -1, self.head_dim) 39 | shape_kv = (*key_states.shape[:-1], -1, self.head_dim) 40 | 41 | query_states = query_states.view(shape_q).transpose(1, 2) 42 | key_states = key_states.view(shape_kv).transpose(1, 2) 43 | value_states = value_states.view(shape_kv).transpose(1, 2) 44 | 45 | if layer_past is not None: 46 | past_key, past_value = layer_past 47 | key_states = torch.cat((past_key, key_states), dim=-2) 48 | value_states = torch.cat((past_value, value_states), dim=-2) 49 | 50 | if use_cache is True: 51 | present = (key_states, value_states) 52 | else: 53 | present = None 54 | 55 | is_cross_attention = encoder_hidden_states is not None 56 | is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention 57 | 58 | using_eager = self.config._attn_implementation == "eager" 59 | attention_interface = eager_attention_forward 60 | if self.config._attn_implementation != "eager": 61 | if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None): 62 | using_eager = True 63 | else: 64 | # Attention functions are consistent with previous equivalent attention classes, however they do not support some options 65 | # (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but 66 | # not necessarily to eager (if mentionned options are provided). 
67 | attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] 68 | 69 | if using_eager and self.reorder_and_upcast_attn: 70 | attn_output, attn_weights = self._upcast_and_reordered_attn( 71 | query_states, key_states, value_states, attention_mask, head_mask 72 | ) 73 | else: 74 | attn_output, attn_weights = attention_interface( 75 | self, 76 | query_states, 77 | key_states, 78 | value_states, 79 | attention_mask, 80 | head_mask=head_mask, 81 | dropout=self.attn_dropout.p if self.training else 0.0, 82 | is_causal=is_causal, 83 | **kwargs, 84 | ) 85 | 86 | attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous() 87 | attn_output = self.o_attn(attn_output) 88 | attn_output = self.resid_dropout(attn_output) 89 | 90 | outputs = (attn_output, present) 91 | if output_attentions: 92 | outputs += (attn_weights,) 93 | 94 | return outputs 95 | 96 | 97 | 98 | class CorrelationAttention(BaselineAttention): 99 | def __init__(self, config, is_cross_attention=False, layer_idx=None, **kwargs): 100 | super().__init__(config, is_cross_attention, layer_idx) 101 | if kwargs.get("mod_q", False): 102 | self.q_attn = nn.Identity() 103 | if kwargs.get("mod_k", False): 104 | self.k_attn = nn.Identity() 105 | 106 | # bonus 107 | if kwargs.get("mod_v", False): 108 | self.v_attn = nn.Identity() 109 | if kwargs.get("mod_o", False): 110 | self.o_attn = nn.Identity() 111 | 112 | def init_to_identity(module): 113 | if isinstance(module, nn.Linear): 114 | module.weight.data.fill_(0.0) 115 | module.weight.data[:, :] = torch.eye(module.weight.data.shape[1]) 116 | 117 | class CorrelationInitAttention(BaselineAttention): 118 | def __init__(self, config, is_cross_attention=False, layer_idx=None, **kwargs): 119 | super().__init__(config, is_cross_attention, layer_idx) 120 | if kwargs.get("mod_q", False): 121 | init_to_identity(self.q_attn) 122 | if kwargs.get("mod_k", False): 123 | init_to_identity(self.k_attn) 124 | 125 | # bonus 126 | if kwargs.get("mod_v", False): 127 | init_to_identity(self.v_attn) 128 | if kwargs.get("mod_o", False): 129 | init_to_identity(self.o_attn) 130 | 131 | 132 | class ResidualLinear(nn.Linear): 133 | def forward(self, input): 134 | return input + super().forward(input) 135 | 136 | # class ResidualLinear(nn.Linear): 137 | # def __init__(self, in_features, out_features, bias=False, scale=1.0, trainable_scale=False): 138 | # super().__init__(in_features, out_features, bias) 139 | # scale = torch.Tensor([scale]) 140 | # if trainable_scale: 141 | # self.scale = nn.Parameter(scale) 142 | # else: 143 | # self.register_buffer("scale", scale) 144 | 145 | # def forward(self, input): 146 | # return input * self.scale + super().forward(input) * (1 - self.scale) 147 | 148 | 149 | class ResidualAttention(BaselineAttention): 150 | def __init__(self, config, is_cross_attention=False, layer_idx=None, **kwargs): 151 | super().__init__(config, is_cross_attention, layer_idx) 152 | if kwargs.get("mod_q", False): 153 | self.q_attn = ResidualLinear(config.n_embd, config.n_embd, bias=False) 154 | # nn.init.xavier_uniform_(self.q_attn.weight) 155 | if kwargs.get("mod_k", False): 156 | self.k_attn = ResidualLinear(config.n_embd, config.n_embd, bias=False) 157 | # nn.init.xavier_uniform_(self.k_attn.weight) 158 | # bonus 159 | if kwargs.get("mod_v", False): 160 | self.v_attn = ResidualLinear(config.n_embd, config.n_embd, bias=False) 161 | # nn.init.xavier_uniform_(self.v_attn.weight) 162 | if kwargs.get("mod_o", False): 163 | self.o_attn = ResidualLinear(config.n_embd, config.n_embd, 
bias=True) 164 | # nn.init.xavier_uniform_(self.o_attn.weight) 165 | 166 | 167 | def patch_model(model, args, exp_args): 168 | idx = 0 169 | for n,m in model.named_modules(): 170 | if hasattr(m, "attn"): 171 | if exp_args["mode"] == "correlation": 172 | m.attn = CorrelationAttention(model.config, **exp_args) 173 | elif exp_args["mode"] == "correlation_init": 174 | m.attn = CorrelationInitAttention(model.config, **exp_args) 175 | elif exp_args["mode"] == "residual": 176 | m.attn = ResidualAttention(model.config, **exp_args) 177 | elif exp_args["mode"] == "base": 178 | m.attn = BaselineAttention(model.config, **exp_args) 179 | elif exp_args["mode"] == "dummy": 180 | m.attn = nn.Identity() 181 | else: 182 | raise ValueError(f"Invalid mode: {exp_args['mode']}") 183 | 184 | return model 185 | 186 | 187 | 188 | def get_run_name(args, exp_args): 189 | run_name = "mode_" + exp_args["mode"] + "_modq_" + str(exp_args["mod_q"]) + "_modk_" + str(exp_args["mod_k"]) + "_modv_" + str(exp_args["mod_v"]) + "_ modo_" + str(exp_args["mod_o"]) + "_lr:" + str(args["learning_rate"]) 190 | args["output_dir"] = f"{args['base_output_dir']}/{run_name}" 191 | 192 | return args, run_name 193 | 194 | 195 | extra_args = { 196 | # "mode": ["correlation", "correlation_init", "residual", "dummy", "base"], 197 | # "mod_q": [True, False], 198 | # "mod_k": [True, False], 199 | # "mod_v": [True, False], 200 | # "mod_o": [True, False], 201 | 202 | # baseline 203 | "mode": "base", 204 | "mod_q": False, 205 | "mod_k": False, 206 | "mod_v": False, 207 | "mod_o": False, 208 | 209 | # correlation, no v/o 210 | # "mode": "correlation", 211 | # "mod_q": True, 212 | # "mod_k": True, 213 | # "mod_v": False, 214 | # "mod_o": False, 215 | 216 | # pure correlation 217 | # "mode": "correlation", 218 | # "mod_q": True, 219 | # "mod_k": True, 220 | # "mod_v": True, 221 | # "mod_o": True, 222 | 223 | # init correlation no v/o 224 | # "mode": "correlation_init", 225 | # "mod_q": True, 226 | # "mod_k": True, 227 | # "mod_v": False, 228 | # "mod_o": False, 229 | 230 | # init correlation 231 | # "mode": "correlation_init", 232 | # "mod_q": True, 233 | # "mod_k": True, 234 | # "mod_v": True, 235 | # "mod_o": True, 236 | 237 | # residual no v/o 238 | # "mode": "residual", 239 | # "mod_q": True, 240 | # "mod_k": True, 241 | # "mod_v": False, 242 | # "mod_o": False, 243 | # "trainable_scale": True, 244 | 245 | # residual 246 | # "mode": "residual", 247 | # "mod_q": True, 248 | # "mod_k": True, 249 | # "mod_v": True, 250 | # "mod_o": True, 251 | # "trainable_scale": True, 252 | 253 | 254 | 255 | 256 | 257 | } 258 | -------------------------------------------------------------------------------- /boneless_attn/vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | 9 | class BaselineAttention(nn.Module): 10 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., **kwargs): 11 | super().__init__() 12 | inner_dim = dim_head * heads 13 | self.heads = heads 14 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 15 | self.to_k = nn.Linear(dim, inner_dim, bias = False) 16 | self.to_v = nn.Linear(dim, inner_dim, bias = False) 17 | self.to_o = nn.Linear(inner_dim, dim, bias = True) 18 | self.to_out = nn.Sequential( 19 | nn.Linear(inner_dim, dim), 20 | nn.Dropout(dropout) 21 | ) 22 | # xavier init 23 | 
nn.init.xavier_normal_(self.to_q.weight) 24 | nn.init.xavier_normal_(self.to_k.weight) 25 | nn.init.xavier_normal_(self.to_v.weight) 26 | nn.init.xavier_normal_(self.to_o[0].weight) 27 | 28 | 29 | def forward(self, x): 30 | q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) 31 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q,k,v)) 32 | out = torch.nn.functional.scaled_dot_product_attention(q, k, v) 33 | out = rearrange(out, 'b h n d -> b n (h d)') 34 | out = self.to_out(out) 35 | return out 36 | 37 | 38 | class CorrelationAttention(BaselineAttention): 39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., **kwargs): 40 | super().__init__(dim, heads, dim_head, dropout) 41 | if kwargs.get("mod_q", False): 42 | self.to_q = nn.Identity() 43 | if kwargs.get("mod_k", False): 44 | self.to_k = nn.Identity() 45 | if kwargs.get("mod_v", False): 46 | self.to_v = nn.Identity() 47 | if kwargs.get("mod_o", False): 48 | self.to_o = nn.Sequential( 49 | nn.Identity(), 50 | nn.Dropout(dropout) 51 | ) 52 | 53 | def init_to_identity(module): 54 | if isinstance(module, nn.Linear): 55 | module.weight.data.fill_(0.0) 56 | module.weight.data[:, :] = torch.eye(module.weight.data.shape[1]) 57 | 58 | class CorrelationInitAttention(BaselineAttention): 59 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., **kwargs): 60 | super().__init__(dim, heads, dim_head, dropout) 61 | if kwargs.get("mod_q", False): 62 | init_to_identity(self.to_q) 63 | if kwargs.get("mod_k", False): 64 | init_to_identity(self.to_k) 65 | if kwargs.get("mod_v", False): 66 | init_to_identity(self.to_v) 67 | if kwargs.get("mod_o", False): 68 | init_to_identity(self.to_o[0]) 69 | 70 | # class ResidualLinear(nn.Linear): 71 | # def forward(self, input): 72 | # return input + super().forward(input) 73 | 74 | class ResidualLinear(nn.Linear): 75 | def __init__(self, in_features, out_features, bias=False, scale=1.0): 76 | super().__init__(in_features, out_features, bias) 77 | scale = torch.Tensor([scale]) 78 | self.register_buffer("scale", scale) 79 | 80 | def forward(self, input): 81 | return input * self.scale + super().forward(input) * self.scale 82 | 83 | class ResidualAttention(BaselineAttention): 84 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., **kwargs): 85 | super().__init__(dim, heads, dim_head, dropout) 86 | if kwargs.get("mod_q", False): 87 | self.to_q = ResidualLinear(dim, dim, bias=False) 88 | if kwargs.get("mod_k", False): 89 | self.to_k = ResidualLinear(dim, dim, bias=False) 90 | if kwargs.get("mod_v", False): 91 | self.to_v = ResidualLinear(dim, dim, bias=False) 92 | if kwargs.get("mod_o", False): 93 | self.to_o = nn.Sequential( 94 | ResidualLinear(dim, dim, bias=True), 95 | nn.Dropout(dropout) 96 | ) 97 | 98 | def patch_model(model, args, exp_args): 99 | for name, m in model.named_modules(): 100 | if hasattr(m, "attn"): 101 | if exp_args["mode"] == "correlation": 102 | m.attn = CorrelationAttention(model.config, **exp_args) 103 | elif exp_args["mode"] == "correlation_init": 104 | m.attn = CorrelationInitAttention(model.config, **exp_args) 105 | elif exp_args["mode"] == "residual": 106 | m.attn = ResidualAttention(model.config, **exp_args) 107 | elif exp_args["mode"] == "base": 108 | m.attn = BaselineAttention(model.config, **exp_args) 109 | elif exp_args["mode"] == "dummy": 110 | m.attn = nn.Identity() 111 | else: 112 | raise ValueError(f"Invalid mode: {exp_args['mode']}") 113 | 114 | return model 
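# ---------------------------------------------------------------------------
# Illustrative sketch (added; not part of the original file). With mod_q and
# mod_k active, CorrelationAttention above replaces to_q / to_k by nn.Identity,
# so the attention logits become a scaled token-token correlation (Gram) matrix
# of the input itself. The single-head toy example below only assumes torch;
# head splitting, dropout and the output projection are omitted for brevity.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x = torch.randn(2, 16, 64)                        # (batch, tokens, dim)
    scores = x @ x.transpose(-1, -2) / (x.shape[-1] ** 0.5)
    attn = scores.softmax(dim=-1)                     # "boneless" attention weights
    out = attn @ x                                    # with mod_v, values are the raw tokens too
    print(out.shape)                                  # torch.Size([2, 16, 64])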
-------------------------------------------------------------------------------- /calibrated_attention/gpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional, Tuple 4 | from .attentions import ( 5 | AttentionBase, 6 | AttentionRelativeScaling1, 7 | AttentionRelativeScaling2, 8 | AttentionYarnScaling, 9 | AttentionPolyFitScaling, 10 | AttentionLearnedScaling, 11 | AttentionSoftmaxPlusOne, 12 | AttentionSoftmaxPlusFN 13 | ) 14 | def patch_model(model, args, exp_args): 15 | for n,m in model.named_modules(): 16 | if hasattr(m, "attn"): 17 | dim = model.config.n_embd 18 | heads = m.attn.num_heads 19 | 20 | if exp_args["mode"] == "base": 21 | m.attn = AttentionBase(dim=dim, heads=heads) 22 | elif exp_args["mode"] == "relative_scaling_1": 23 | m.attn = AttentionRelativeScaling1(dim=dim, heads=heads, base_seq_len=exp_args["base_seq_len"]) 24 | elif exp_args["mode"] == "relative_scaling_2": 25 | m.attn = AttentionRelativeScaling2(dim=dim, heads=heads, base_seq_len=exp_args["base_seq_len"], attn_bias=exp_args["attn_bias"], learned_bias=exp_args["learned_bias"]) 26 | elif exp_args["mode"] == "yarn_scaling": 27 | m.attn = AttentionYarnScaling(dim=dim, heads=heads) 28 | elif exp_args["mode"] == "polyfit_scaling": 29 | m.attn = AttentionPolyFitScaling(dim=dim, heads=heads) 30 | elif exp_args["mode"] == "learned_scaling": 31 | m.attn = AttentionLearnedScaling(dim=dim, heads=heads) 32 | elif exp_args["mode"] == "softmax_plus_one": 33 | m.attn = AttentionSoftmaxPlusOne(dim=dim, heads=heads) 34 | elif exp_args["mode"] == "softmax_plus_fn": 35 | m.attn = AttentionSoftmaxPlusFN(dim=dim, heads=heads) 36 | else: 37 | raise ValueError(f"Invalid mode: {exp_args['mode']}") 38 | 39 | return model 40 | 41 | 42 | 43 | def get_run_name(args, exp_args): 44 | if exp_args["mode"] in ["base", "yarn_scaling", "polyfit_scaling", "learned_scaling", "softmax_plus_one", "softmax_plus_fn"]: 45 | run_name = f"{exp_args['mode']}_lr:{args['learning_rate']}" 46 | elif exp_args["mode"] in ["relative_scaling_1", "relative_scaling_2"]: 47 | run_name = f"{exp_args['mode']}_base_seq_len:{exp_args['base_seq_len']}_lr:{args['learning_rate']}" 48 | elif exp_args["mode"] in ["relative_scaling_2"]: 49 | run_name = f"{exp_args['mode']}_base_seq_len:{exp_args['base_seq_len']}_attn_bias:{exp_args['attn_bias']}_learned_bias:{exp_args['learned_bias']}_lr:{args['learning_rate']}" 50 | else: 51 | raise ValueError(f"Invalid mode: {exp_args['mode']}") 52 | args["output_dir"] = f"{args['base_output_dir']}/{run_name}" 53 | 54 | return args, run_name 55 | 56 | 57 | extra_args = { 58 | # base 59 | # "mode": "base", 60 | 61 | # # relative scaling 62 | # "mode": "relative_scaling_1", 63 | # "base_seq_len": 256, 64 | 65 | # # relative scaling 2 66 | # "mode": "relative_scaling_2", 67 | # "base_seq_len": 256, 68 | # "attn_bias": 1.5, 69 | # "learned_bias": False, 70 | 71 | # # softmax plus one 72 | # "mode": "softmax_plus_one", 73 | 74 | # # softmax plus fn 75 | # "mode": "softmax_plus_fn", 76 | 77 | # # polyfit scaling 78 | # "mode": "polyfit_scaling", 79 | 80 | # # learned scaling 81 | "mode": "learned_scaling", 82 | 83 | # # softmax plus one 84 | # "mode": "softmax_plus_one", 85 | 86 | # # softmax plus fn 87 | # "mode": "softmax_plus_fn", 88 | 89 | 90 | } 91 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.activations import Activation -------------------------------------------------------------------------------- /common/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class SignGelu2(nn.Module): 7 | 8 | def __init__(self, neg_scale=20.0): 9 | super().__init__() 10 | self.neg_scale = neg_scale 11 | 12 | def forward(self, x): 13 | sign = torch.sign(x) 14 | sign = torch.where(sign < 0, sign * self.neg_scale, sign) 15 | return sign * F.gelu(x).square() 16 | 17 | 18 | class SinLU(nn.Module): 19 | def __init__(self, dim=None): 20 | super(SinLU,self).__init__() 21 | dim = 1 if dim is None else dim 22 | self.a = nn.Parameter(torch.ones(dim)) 23 | self.b = nn.Parameter(torch.ones(dim)) 24 | def forward(self,x): 25 | return torch.sigmoid(x)*(x+self.a*torch.sin(self.b*x)) 26 | 27 | 28 | class NormalizedExp(nn.Module): 29 | 30 | def __init__(self, beta=0.99): 31 | super().__init__() 32 | self.register_buffer("avg_max", torch.tensor(1.0)) 33 | self.beta = beta 34 | 35 | def forward(self, x): 36 | max_val = torch.max(x) / 2 37 | self.avg_max = self.beta * self.avg_max + (1 - self.beta) * max_val 38 | return torch.exp(x - self.avg_max) 39 | 40 | 41 | class LinearAct(nn.Module): 42 | 43 | def __init__(self, in_features, out_features, bias=True, activation_type='relu', power=1.0, pre_act=False): 44 | super().__init__() 45 | self.pre_act = Activation(activation_type, power, dim=in_features) if pre_act else nn.Identity() 46 | self.linear = nn.Linear(in_features, out_features, bias=bias) 47 | self.post_act = Activation(activation_type, power, dim=out_features) if not pre_act else nn.Identity() 48 | 49 | def forward(self, x): 50 | return self.post_act(self.linear(self.pre_act(x))) 51 | 52 | 53 | from torch.autograd import Function 54 | 55 | class ReluForwardSiluBackward(Function): 56 | @staticmethod 57 | def forward(ctx, input): 58 | ctx.save_for_backward(input) 59 | return input.clamp(min=0) 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | input = ctx.saved_tensors[0] 64 | sigmoid = torch.sigmoid(input) 65 | grad_input = grad_output * (sigmoid * (1 + input * (1 - sigmoid))) 66 | return grad_input 67 | 68 | relu_fwd_silu_bwd = ReluForwardSiluBackward.apply 69 | 70 | class Activation(nn.Module): 71 | 72 | def __init__(self, activation_type: str = 'relu', power=1.0, dim=None): 73 | super().__init__() 74 | self.activation_type = activation_type 75 | if activation_type == 'relu': 76 | activation = lambda x: torch.nn.functional.relu(x) 77 | elif activation_type == 'gelu': 78 | activation = lambda x: torch.nn.functional.gelu(x) 79 | elif activation_type == 'silu': 80 | activation = lambda x: torch.nn.functional.silu(x) 81 | elif activation_type == 'tanh': 82 | activation = lambda x: torch.tanh(x) 83 | elif activation_type == 'leaky': 84 | activation = lambda x: torch.nn.functional.leaky_relu(x, negative_slope=0.2) 85 | elif activation_type == 'sin': 86 | activation = lambda x: torch.sin(x) 87 | elif activation_type == 'sin_residual': 88 | activation = lambda x: torch.sin(x) + (x/2) 89 | elif activation_type == 'relu_sin': 90 | activation = lambda x: torch.sin(torch.relu(x)) + torch.relu(x/2) 91 | elif activation_type == 'norm_exp': 92 | activation = NormalizedExp() 93 | elif activation_type == 'sign_gelu2': 94 | activation = SignGelu2(neg_scale=20.0) 95 | elif activation_type == 'sinlu': 96 | activation = SinLU(dim=dim) 97 | elif activation_type == 
'relu_fwd_silu_bwd': 98 | activation = relu_fwd_silu_bwd 99 | 100 | if power != 1.0: 101 | self.activation = lambda x: torch.pow(activation(x), power) 102 | else: 103 | self.activation = activation 104 | 105 | def forward(self, x): 106 | return self.activation(x) -------------------------------------------------------------------------------- /common/cifar_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | '''Some helper functions for PyTorch, including: 4 | - get_mean_and_std: calculate the mean and std value of dataset. 5 | - msr_init: net parameter initialization. 6 | - progress_bar: progress bar mimic xlua.progress. 7 | ''' 8 | import os 9 | import sys 10 | import time 11 | import math 12 | 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | import torch 16 | import torchvision 17 | import torchvision.transforms as transforms 18 | from .randomaug import RandAugment 19 | 20 | def get_mean_and_std(dataset): 21 | '''Compute the mean and std value of dataset.''' 22 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | 50 | try: 51 | _, term_width = os.popen('stty size', 'r').read().split() 52 | except: 53 | term_width = 80 54 | term_width = int(term_width) 55 | 56 | TOTAL_BAR_LENGTH = 65. 57 | last_time = time.time() 58 | begin_time = last_time 59 | def progress_bar(current, total, msg=None): 60 | global last_time, begin_time 61 | if current == 0: 62 | begin_time = time.time() # Reset for new bar. 63 | 64 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 65 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 66 | 67 | sys.stdout.write(' [') 68 | for i in range(cur_len): 69 | sys.stdout.write('=') 70 | sys.stdout.write('>') 71 | for i in range(rest_len): 72 | sys.stdout.write('.') 73 | sys.stdout.write(']') 74 | 75 | cur_time = time.time() 76 | step_time = cur_time - last_time 77 | last_time = cur_time 78 | tot_time = cur_time - begin_time 79 | 80 | L = [] 81 | L.append(' Step: %s' % format_time(step_time)) 82 | L.append(' | Tot: %s' % format_time(tot_time)) 83 | if msg: 84 | L.append(' | ' + msg) 85 | 86 | msg = ''.join(L) 87 | sys.stdout.write(msg) 88 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 89 | sys.stdout.write(' ') 90 | 91 | # Go back to the center of the bar. 
92 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 93 | sys.stdout.write('\b') 94 | sys.stdout.write(' %d/%d ' % (current+1, total)) 95 | 96 | if current < total-1: 97 | sys.stdout.write('\r') 98 | else: 99 | sys.stdout.write('\n') 100 | sys.stdout.flush() 101 | 102 | def format_time(seconds): 103 | days = int(seconds / 3600/24) 104 | seconds = seconds - days*3600*24 105 | hours = int(seconds / 3600) 106 | seconds = seconds - hours*3600 107 | minutes = int(seconds / 60) 108 | seconds = seconds - minutes*60 109 | secondsf = int(seconds) 110 | seconds = seconds - secondsf 111 | millis = int(seconds*1000) 112 | 113 | f = '' 114 | i = 1 115 | if days > 0: 116 | f += str(days) + 'D' 117 | i += 1 118 | if hours > 0 and i <= 2: 119 | f += str(hours) + 'h' 120 | i += 1 121 | if minutes > 0 and i <= 2: 122 | f += str(minutes) + 'm' 123 | i += 1 124 | if secondsf > 0 and i <= 2: 125 | f += str(secondsf) + 's' 126 | i += 1 127 | if millis > 0 and i <= 2: 128 | f += str(millis) + 'ms' 129 | i += 1 130 | if f == '': 131 | f = '0ms' 132 | return f 133 | 134 | 135 | def load_data(args): 136 | # Data 137 | print('==> Preparing data..') 138 | transform_train = transforms.Compose([ 139 | transforms.RandomCrop(32, padding=4), 140 | transforms.Resize(args.size), 141 | transforms.RandomHorizontalFlip(), 142 | transforms.ToTensor(), 143 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 144 | ]) 145 | 146 | transform_test = transforms.Compose([ 147 | transforms.Resize(args.size), 148 | transforms.ToTensor(), 149 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 150 | ]) 151 | 152 | # Add RandAugment with N, M(hyperparameter) 153 | if args.aug: 154 | transform_train.transforms.insert(0, RandAugment(2, 14)) 155 | 156 | # Prepare dataset 157 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 158 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.bs, shuffle=True, num_workers=8) 159 | 160 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 161 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8) 162 | 163 | #classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 164 | 165 | return trainloader, testloader 166 | 167 | 168 | 169 | def train(args, epoch, net, net_forward, trainloader, optimizer, scaler, loss_fn=None, optimizer_callback=None): 170 | print('\nEpoch: %d' % epoch) 171 | net.train() 172 | train_loss = 0 173 | correct = 0 174 | total = 0 175 | mp_dtype = torch.float32 176 | if args.mp_dtype=="bf16": 177 | mp_dtype = torch.bfloat16 178 | elif args.mp_dtype=="fp16": 179 | mp_dtype = torch.float16 180 | for batch_idx, (inputs, targets) in enumerate(trainloader): 181 | inputs, targets = inputs.to(args.device), targets.to(args.device) 182 | # Train with amp 183 | with torch.cuda.amp.autocast(enabled=args.mp_dtype!="fp32", dtype=mp_dtype): 184 | loss, preds = loss_fn(net_forward, inputs, targets) 185 | scaler.scale(loss).backward() 186 | scaler.step(optimizer) 187 | scaler.update() 188 | optimizer.zero_grad(set_to_none=True) 189 | if optimizer_callback is not None: 190 | optimizer_callback(optimizer) 191 | 192 | train_loss += loss.item() 193 | _, predicted = preds.max(1) 194 | total += targets.size(0) 195 | correct += predicted.eq(targets).sum().item() 196 | 197 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 
198 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 199 | return train_loss/(batch_idx+1) 200 | 201 | ##### Validation 202 | def test(args, epoch, net, net_forward, testloader, optimizer, scaler): 203 | net.eval() 204 | test_loss = 0 205 | correct = 0 206 | total = 0 207 | with torch.no_grad(): 208 | for batch_idx, (inputs, targets) in enumerate(testloader): 209 | inputs, targets = inputs.to(args.device), targets.to(args.device) 210 | pred_labels = net_forward(inputs) 211 | loss = nn.CrossEntropyLoss()(pred_labels, targets) 212 | 213 | test_loss += loss.item() 214 | _, predicted = pred_labels.max(1) 215 | total += targets.size(0) 216 | correct += predicted.eq(targets).sum().item() 217 | 218 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 219 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 220 | 221 | # Save checkpoint. 222 | acc = 100.*correct/total 223 | if acc > args.best_acc: 224 | print('Saving..') 225 | state = {"model": net.state_dict(), 226 | "optimizer": optimizer.state_dict(), 227 | "scaler": scaler.state_dict()} 228 | if not os.path.isdir('checkpoint'): 229 | os.mkdir('checkpoint') 230 | torch.save(state, './checkpoint/'+args.net+'-{}-ckpt.t7'.format(args.patch)) 231 | args.best_acc = acc 232 | 233 | os.makedirs("log", exist_ok=True) 234 | content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}' 235 | print(content) 236 | with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender: 237 | appender.write(content + "\n") 238 | return test_loss, acc -------------------------------------------------------------------------------- /common/randomaug.py: -------------------------------------------------------------------------------- 1 | # code in this file is adpated from rpmcruz/autoaugment 2 | # https://github.com/rpmcruz/autoaugment/blob/master/transformations.py 3 | import random 4 | 5 | import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | 11 | def ShearX(img, v): # [-0.3, 0.3] 12 | assert -0.3 <= v <= 0.3 13 | if random.random() > 0.5: 14 | v = -v 15 | return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) 16 | 17 | 18 | def ShearY(img, v): # [-0.3, 0.3] 19 | assert -0.3 <= v <= 0.3 20 | if random.random() > 0.5: 21 | v = -v 22 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) 23 | 24 | 25 | def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 26 | assert -0.45 <= v <= 0.45 27 | if random.random() > 0.5: 28 | v = -v 29 | v = v * img.size[0] 30 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 31 | 32 | 33 | def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 34 | assert 0 <= v 35 | if random.random() > 0.5: 36 | v = -v 37 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) 38 | 39 | 40 | def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 41 | assert -0.45 <= v <= 0.45 42 | if random.random() > 0.5: 43 | v = -v 44 | v = v * img.size[1] 45 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 46 | 47 | 48 | def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] 49 | assert 0 <= v 50 | if random.random() > 0.5: 51 | v = -v 52 | return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) 53 | 54 | 55 | def Rotate(img, v): # [-30, 30] 56 | assert -30 <= v <= 30 57 | if random.random() > 
0.5: 58 | v = -v 59 | return img.rotate(v) 60 | 61 | 62 | def AutoContrast(img, _): 63 | return PIL.ImageOps.autocontrast(img) 64 | 65 | 66 | def Invert(img, _): 67 | return PIL.ImageOps.invert(img) 68 | 69 | 70 | def Equalize(img, _): 71 | return PIL.ImageOps.equalize(img) 72 | 73 | 74 | def Flip(img, _): # not from the paper 75 | return PIL.ImageOps.mirror(img) 76 | 77 | 78 | def Solarize(img, v): # [0, 256] 79 | assert 0 <= v <= 256 80 | return PIL.ImageOps.solarize(img, v) 81 | 82 | 83 | def SolarizeAdd(img, addition=0, threshold=128): 84 | img_np = np.array(img).astype(int) 85 | img_np = img_np + addition 86 | img_np = np.clip(img_np, 0, 255) 87 | img_np = img_np.astype(np.uint8) 88 | img = Image.fromarray(img_np) 89 | return PIL.ImageOps.solarize(img, threshold) 90 | 91 | 92 | def Posterize(img, v): # [4, 8] 93 | v = int(v) 94 | v = max(1, v) 95 | return PIL.ImageOps.posterize(img, v) 96 | 97 | 98 | def Contrast(img, v): # [0.1,1.9] 99 | assert 0.1 <= v <= 1.9 100 | return PIL.ImageEnhance.Contrast(img).enhance(v) 101 | 102 | 103 | def Color(img, v): # [0.1,1.9] 104 | assert 0.1 <= v <= 1.9 105 | return PIL.ImageEnhance.Color(img).enhance(v) 106 | 107 | 108 | def Brightness(img, v): # [0.1,1.9] 109 | assert 0.1 <= v <= 1.9 110 | return PIL.ImageEnhance.Brightness(img).enhance(v) 111 | 112 | 113 | def Sharpness(img, v): # [0.1,1.9] 114 | assert 0.1 <= v <= 1.9 115 | return PIL.ImageEnhance.Sharpness(img).enhance(v) 116 | 117 | 118 | def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] 119 | assert 0.0 <= v <= 0.2 120 | if v <= 0.: 121 | return img 122 | 123 | v = v * img.size[0] 124 | return CutoutAbs(img, v) 125 | 126 | 127 | def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] 128 | # assert 0 <= v <= 20 129 | if v < 0: 130 | return img 131 | w, h = img.size 132 | x0 = np.random.uniform(w) 133 | y0 = np.random.uniform(h) 134 | 135 | x0 = int(max(0, x0 - v / 2.)) 136 | y0 = int(max(0, y0 - v / 2.)) 137 | x1 = min(w, x0 + v) 138 | y1 = min(h, y0 + v) 139 | 140 | xy = (x0, y0, x1, y1) 141 | color = (125, 123, 114) 142 | # color = (0, 0, 0) 143 | img = img.copy() 144 | PIL.ImageDraw.Draw(img).rectangle(xy, color) 145 | return img 146 | 147 | 148 | def SamplePairing(imgs): # [0, 0.4] 149 | def f(img1, v): 150 | i = np.random.choice(len(imgs)) 151 | img2 = PIL.Image.fromarray(imgs[i]) 152 | return PIL.Image.blend(img1, img2, v) 153 | 154 | return f 155 | 156 | 157 | def Identity(img, v): 158 | return img 159 | 160 | 161 | def augment_list(): # 16 oeprations and their ranges 162 | # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 163 | # l = [ 164 | # (Identity, 0., 1.0), 165 | # (ShearX, 0., 0.3), # 0 166 | # (ShearY, 0., 0.3), # 1 167 | # (TranslateX, 0., 0.33), # 2 168 | # (TranslateY, 0., 0.33), # 3 169 | # (Rotate, 0, 30), # 4 170 | # (AutoContrast, 0, 1), # 5 171 | # (Invert, 0, 1), # 6 172 | # (Equalize, 0, 1), # 7 173 | # (Solarize, 0, 110), # 8 174 | # (Posterize, 4, 8), # 9 175 | # # (Contrast, 0.1, 1.9), # 10 176 | # (Color, 0.1, 1.9), # 11 177 | # (Brightness, 0.1, 1.9), # 12 178 | # (Sharpness, 0.1, 1.9), # 13 179 | # # (Cutout, 0, 0.2), # 14 180 | # # (SamplePairing(imgs), 0, 0.4), # 15 181 | # ] 182 | 183 | # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 184 | l = [ 185 | (AutoContrast, 0, 1), 186 | (Equalize, 0, 1), 187 | (Invert, 0, 1), 188 | (Rotate, 0, 30), 189 | (Posterize, 0, 4), 190 | (Solarize, 0, 256), 191 | (SolarizeAdd, 0, 110), 192 | 
(Color, 0.1, 1.9), 193 | (Contrast, 0.1, 1.9), 194 | (Brightness, 0.1, 1.9), 195 | (Sharpness, 0.1, 1.9), 196 | (ShearX, 0., 0.3), 197 | (ShearY, 0., 0.3), 198 | (CutoutAbs, 0, 40), 199 | (TranslateXabs, 0., 100), 200 | (TranslateYabs, 0., 100), 201 | ] 202 | 203 | return l 204 | 205 | 206 | class Lighting(object): 207 | """Lighting noise(AlexNet - style PCA - based noise)""" 208 | 209 | def __init__(self, alphastd, eigval, eigvec): 210 | self.alphastd = alphastd 211 | self.eigval = torch.Tensor(eigval) 212 | self.eigvec = torch.Tensor(eigvec) 213 | 214 | def __call__(self, img): 215 | if self.alphastd == 0: 216 | return img 217 | 218 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 219 | rgb = self.eigvec.type_as(img).clone() \ 220 | .mul(alpha.view(1, 3).expand(3, 3)) \ 221 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 222 | .sum(1).squeeze() 223 | 224 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 225 | 226 | 227 | class CutoutDefault(object): 228 | """ 229 | Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py 230 | """ 231 | def __init__(self, length): 232 | self.length = length 233 | 234 | def __call__(self, img): 235 | h, w = img.size(1), img.size(2) 236 | mask = np.ones((h, w), np.float32) 237 | y = np.random.randint(h) 238 | x = np.random.randint(w) 239 | 240 | y1 = np.clip(y - self.length // 2, 0, h) 241 | y2 = np.clip(y + self.length // 2, 0, h) 242 | x1 = np.clip(x - self.length // 2, 0, w) 243 | x2 = np.clip(x + self.length // 2, 0, w) 244 | 245 | mask[y1: y2, x1: x2] = 0. 246 | mask = torch.from_numpy(mask) 247 | mask = mask.expand_as(img) 248 | img *= mask 249 | return img 250 | 251 | 252 | class RandAugment: 253 | def __init__(self, n, m): 254 | self.n = n 255 | self.m = m # [0, 30] 256 | self.augment_list = augment_list() 257 | 258 | def __call__(self, img): 259 | ops = random.choices(self.augment_list, k=self.n) 260 | for op, minval, maxval in ops: 261 | val = (float(self.m) / 30) * float(maxval - minval) + minval 262 | img = op(img, val) 263 | 264 | return img 265 | -------------------------------------------------------------------------------- /common/vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | # helpers 10 | 11 | def pair(t): 12 | return t if isinstance(t, tuple) else (t, t) 13 | 14 | 15 | class FeedForward(nn.Module): 16 | def __init__(self, dim, hidden_dim, dropout = 0.): 17 | super().__init__() 18 | self.net = nn.Sequential( 19 | nn.Linear(dim, hidden_dim), 20 | nn.GELU(), 21 | nn.Dropout(dropout), 22 | nn.Linear(hidden_dim, dim), 23 | nn.Dropout(dropout) 24 | ) 25 | def forward(self, x): 26 | return self.net(x) 27 | 28 | class Attention(nn.Module): 29 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 30 | super().__init__() 31 | inner_dim = dim_head * heads 32 | self.heads = heads 33 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 34 | self.to_out = nn.Sequential( 35 | nn.Linear(inner_dim, dim), 36 | nn.Dropout(dropout) 37 | ) 38 | 39 | def forward(self, x): 40 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (self.to_qkv(x).chunk(3, dim = -1))) 41 | out = torch.nn.functional.scaled_dot_product_attention(q, k, v) 42 | out = rearrange(out, 'b h n d -> b n (h d)') 43 | return self.to_out(out) 44 | 45 | class 
Block(nn.Module): 46 | 47 | def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.): 48 | super().__init__() 49 | self.attn_norm = nn.LayerNorm(dim) 50 | self.ff_norm = nn.LayerNorm(dim) 51 | self.attn = Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout) 52 | self.ff = FeedForward(dim, mlp_dim, dropout) 53 | 54 | def forward(self, x): 55 | x = self.attn(self.attn_norm(x)) + x 56 | x = self.ff(self.ff_norm(x)) + x 57 | return x 58 | 59 | class Transformer(nn.Module): 60 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 61 | super().__init__() 62 | self.blocks = nn.ModuleList([Block(dim, heads, dim_head, mlp_dim, dropout) for i in range(depth)]) 63 | 64 | def forward(self, x): 65 | for block in self.blocks: 66 | x = block(x) 67 | return x 68 | 69 | class ViT(nn.Module): 70 | def __init__(self, *, 71 | dim=512, 72 | depth=6, 73 | heads=8, 74 | mlp_dim=512, 75 | image_size=32, 76 | patch_size=4, 77 | num_classes=10, 78 | channels = 3, 79 | dim_head = 64, 80 | dropout = 0., 81 | emb_dropout = 0., 82 | ): 83 | super().__init__() 84 | image_height, image_width = pair(image_size) 85 | patch_height, patch_width = pair(patch_size) 86 | mlp_dim = dim * 4 87 | 88 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 89 | 90 | num_patches = (image_height // patch_height) * (image_width // patch_width) 91 | patch_dim = channels * patch_height * patch_width 92 | 93 | self.to_patch_embedding = nn.Sequential( 94 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 95 | nn.Linear(patch_dim, dim), 96 | ) 97 | 98 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 99 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim) * 0.02) 100 | self.dropout = nn.Dropout(emb_dropout) 101 | 102 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 103 | 104 | self.mlp_head = nn.Sequential( 105 | nn.LayerNorm(dim), 106 | nn.Linear(dim, num_classes) 107 | ) 108 | 109 | self.init_weights() 110 | 111 | def init_weights(self): 112 | for m in self.modules(): 113 | if isinstance(m, nn.Linear): 114 | nn.init.xavier_normal_(m.weight) 115 | if m.bias is not None: 116 | nn.init.constant_(m.bias, 0) 117 | elif isinstance(m, nn.LayerNorm): 118 | nn.init.constant_(m.bias, 0) 119 | nn.init.constant_(m.weight, 1.0) 120 | 121 | def get_feat(self, img): 122 | x = self.to_patch_embedding(img) 123 | b, n, _ = x.shape 124 | 125 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 126 | x = torch.cat((cls_tokens, x), dim=1) 127 | x += self.pos_embedding[:, :(n + 1)] 128 | x = self.dropout(x) 129 | x = self.transformer(x)[:, 0] 130 | return x 131 | 132 | def forward(self, img): 133 | x = self.get_feat(img) 134 | return self.mlp_head(x) 135 | -------------------------------------------------------------------------------- /dipole_attn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/dipole_attn/__init__.py -------------------------------------------------------------------------------- /dipole_attn/train_cifar10.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | 4 | Train CIFAR10 with PyTorch and Vision Transformers! 
5 | written by @kentaroy47, @arutema47 6 | 7 | ''' 8 | 9 | from __future__ import print_function 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torch.nn.functional as F 15 | import torch.backends.cudnn as cudnn 16 | import numpy as np 17 | 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | 21 | import os 22 | from types import SimpleNamespace 23 | import pandas as pds 24 | import csv 25 | import time 26 | 27 | from utils import progress_bar, load_data, load_model, train, test 28 | from randomaug import RandAugment 29 | 30 | 31 | default_args = dict( 32 | lr = 1e-4, 33 | opt = "adam", 34 | resume = False, 35 | aug = True, 36 | mp_dtype = "bf16", 37 | wandb = True, 38 | mixup = True, 39 | net = "vit", 40 | bs = 512, 41 | size = 32, 42 | n_epochs = 100, 43 | patch = 4, 44 | dim = 64, 45 | convkernel = 8, 46 | num_classes=10, 47 | mlp_dim = 1024, 48 | compile=False, 49 | ) 50 | 51 | 52 | def train_model(args): 53 | def loss_fn(net_fwd, inputs, targets): 54 | pred_labels = net_forward(inputs) 55 | loss = nn.CrossEntropyLoss()(pred_labels, targets) 56 | return loss, pred_labels 57 | 58 | 59 | args = SimpleNamespace(**args) 60 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 61 | args.best_acc = 0 # best test accuracy 62 | args.start_epoch = 0 # start from epoch 0 or last checkpoint epoch 63 | 64 | trainloader, testloader = load_data(args) 65 | net, net_forward = load_model(args) 66 | 67 | print("NUM PARAMS: ", sum([p.numel() for p in net.parameters()])) 68 | 69 | if args.resume: 70 | # Load checkpoint. 71 | print('==> Resuming from checkpoint..') 72 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 73 | checkpoint = torch.load('./checkpoint/{}-ckpt.t7'.format(args.net)) 74 | net.load_state_dict(checkpoint['net']) 75 | args.best_acc = checkpoint['acc'] 76 | args.start_epoch = checkpoint['epoch'] 77 | 78 | optimizer = optim.AdamW(net.parameters(), lr=args.lr) 79 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs) 80 | 81 | ##### Training 82 | scaler = torch.cuda.amp.GradScaler(enabled=args.mp_dtype == "float16") 83 | list_loss = [] 84 | list_acc = [] 85 | 86 | if args.wandb: 87 | import wandb 88 | watermark = "run" 89 | wandb.init(project="cifar10-challange", name=watermark) 90 | wandb.config.update(args) 91 | 92 | if args.wandb: 93 | wandb.watch(net) 94 | 95 | for epoch in range(args.start_epoch, args.n_epochs): 96 | start = time.time() 97 | trainloss = train(args, epoch, net, net_forward, trainloader, optimizer, scaler, loss_fn=loss_fn) 98 | val_loss, acc = test(args, epoch, net, net_forward, testloader, optimizer, scaler) 99 | 100 | scheduler.step(epoch-1) # step cosine scheduling 101 | 102 | list_loss.append(val_loss) 103 | list_acc.append(acc) 104 | 105 | # Log training.. 106 | if args.wandb: 107 | wandb.log({'epoch': epoch, 'train_loss': trainloss, 'val_loss': val_loss, "val_acc": acc, "lr": optimizer.param_groups[0]["lr"], 108 | "epoch_time": time.time()-start}) 109 | 110 | # Write out csv.. 
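# note: the csv below is rewritten from scratch every epoch; the first row holds the
# validation losses accumulated so far, the second row the corresponding validation accuracies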
111 | with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w') as f: 112 | writer = csv.writer(f, lineterminator='\n') 113 | writer.writerow(list_loss) 114 | writer.writerow(list_acc) 115 | print(list_loss) 116 | 117 | # writeout wandb 118 | if args.wandb: 119 | wandb.save("wandb_{}.h5".format(args.net)) 120 | wandb.finish() 121 | 122 | if __name__ == '__main__': 123 | args = default_args 124 | train_model(args) 125 | 126 | -------------------------------------------------------------------------------- /hf_gpt_blocks/hf_gpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple, Union 3 | from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, GPT2Block 4 | from types import MethodType 5 | from torch import nn 6 | from einops import rearrange 7 | from torch.nn.functional import scaled_dot_product_attention as sdpa 8 | 9 | 10 | # GPT2Block 11 | def gpt2_block_forward( 12 | self, 13 | hidden_states: Optional[Tuple[torch.FloatTensor]], 14 | layer_past: Optional[Tuple[torch.Tensor]] = None, 15 | attention_mask: Optional[torch.FloatTensor] = None, 16 | head_mask: Optional[torch.FloatTensor] = None, 17 | encoder_hidden_states: Optional[torch.Tensor] = None, 18 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 19 | use_cache: Optional[bool] = False, 20 | output_attentions: Optional[bool] = False, 21 | ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: 22 | residual = hidden_states 23 | hidden_states = self.ln_1(hidden_states) 24 | attn_outputs = self.attn( 25 | hidden_states, 26 | layer_past=layer_past, 27 | attention_mask=attention_mask, 28 | head_mask=head_mask, 29 | use_cache=use_cache, 30 | output_attentions=output_attentions, 31 | ) 32 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 33 | outputs = attn_outputs[1:] 34 | # residual connection 35 | hidden_states = attn_output + residual 36 | 37 | residual = hidden_states 38 | hidden_states = self.ln_2(hidden_states) 39 | feed_forward_hidden_states = self.mlp(hidden_states) 40 | # residual connection 41 | hidden_states = residual + feed_forward_hidden_states 42 | 43 | # if use_cache: 44 | # outputs = (hidden_states,) + outputs 45 | # else: 46 | outputs = (hidden_states,) + outputs[1:] 47 | 48 | return outputs # hidden_states, present, (attentions, cross_attentions) 49 | 50 | 51 | 52 | 53 | # GPT2Model 54 | def gpt_model_forward( 55 | self, 56 | input_ids: Optional[torch.LongTensor] = None, 57 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 58 | attention_mask: Optional[torch.FloatTensor] = None, 59 | token_type_ids: Optional[torch.LongTensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | head_mask: Optional[torch.FloatTensor] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | encoder_hidden_states: Optional[torch.Tensor] = None, 64 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | return_dict: Optional[bool] = None, 69 | ): 70 | 71 | input_shape = input_ids.size() 72 | input_ids = input_ids.view(-1, input_shape[-1]) 73 | batch_size = input_ids.shape[0] 74 | device = input_ids.device 75 | 76 | 77 | past_length = 0 78 | past_key_values = tuple([None] * len(self.h)) 79 | 80 | position_ids = 
torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) 81 | position_ids = position_ids.unsqueeze(0) 82 | 83 | position_embeds = self.wpe(position_ids) 84 | inputs_embeds = self.wte(input_ids) 85 | hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device) 86 | 87 | # Attention mask. 88 | attention_mask = None 89 | 90 | # Prepare head mask if needed 91 | # 1.0 in head_mask indicate we keep the head 92 | # attention_probs has shape bsz x n_heads x N x N 93 | # head_mask has shape n_layer x batch x n_heads x N x N 94 | head_mask = self.get_head_mask(head_mask, self.config.n_layer) 95 | 96 | hidden_states = self.drop(hidden_states) 97 | 98 | output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) 99 | 100 | use_cache = False 101 | 102 | presents = None 103 | all_self_attentions = None 104 | all_cross_attentions = None 105 | all_hidden_states = None 106 | for i in range(len(self.h)): 107 | block, layer_past = self.h[i], past_key_values[i] 108 | 109 | if self.gradient_checkpointing and self.training: 110 | outputs = self._gradient_checkpointing_func( 111 | block.__call__, 112 | hidden_states, 113 | None, 114 | attention_mask, 115 | head_mask[i], 116 | encoder_hidden_states, 117 | encoder_attention_mask, 118 | use_cache, 119 | output_attentions, 120 | ) 121 | else: 122 | outputs = block( 123 | hidden_states, 124 | layer_past=layer_past, 125 | attention_mask=attention_mask, 126 | head_mask=head_mask[i], 127 | encoder_hidden_states=encoder_hidden_states, 128 | encoder_attention_mask=encoder_attention_mask, 129 | use_cache=use_cache, 130 | output_attentions=output_attentions, 131 | ) 132 | 133 | hidden_states = outputs[0] 134 | 135 | hidden_states = self.ln_f(hidden_states) 136 | 137 | hidden_states = hidden_states.view(output_shape) 138 | 139 | if not return_dict: 140 | return tuple( 141 | v 142 | for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] 143 | if v is not None 144 | ) 145 | 146 | return BaseModelOutputWithPastAndCrossAttentions( 147 | last_hidden_state=hidden_states, 148 | past_key_values=presents, 149 | hidden_states=all_hidden_states, 150 | attentions=all_self_attentions, 151 | cross_attentions=all_cross_attentions, 152 | ) 153 | 154 | 155 | 156 | # GPT2LMHeadModel 157 | def gpt_lm_head_model_forward( 158 | self, 159 | input_ids: Optional[torch.LongTensor] = None, 160 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 161 | attention_mask: Optional[torch.FloatTensor] = None, 162 | token_type_ids: Optional[torch.LongTensor] = None, 163 | position_ids: Optional[torch.LongTensor] = None, 164 | head_mask: Optional[torch.FloatTensor] = None, 165 | inputs_embeds: Optional[torch.FloatTensor] = None, 166 | encoder_hidden_states: Optional[torch.Tensor] = None, 167 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 168 | labels: Optional[torch.LongTensor] = None, 169 | use_cache: Optional[bool] = None, 170 | output_attentions: Optional[bool] = None, 171 | output_hidden_states: Optional[bool] = None, 172 | return_dict: Optional[bool] = None, 173 | **kwargs, 174 | ): 175 | r""" 176 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 177 | Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. 
you can set 178 | `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` 179 | are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` 180 | """ 181 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 182 | 183 | transformer_outputs = self.transformer( 184 | input_ids, 185 | past_key_values=past_key_values, 186 | attention_mask=attention_mask, 187 | token_type_ids=token_type_ids, 188 | position_ids=position_ids, 189 | head_mask=head_mask, 190 | inputs_embeds=inputs_embeds, 191 | encoder_hidden_states=encoder_hidden_states, 192 | encoder_attention_mask=encoder_attention_mask, 193 | use_cache=use_cache, 194 | output_attentions=output_attentions, 195 | output_hidden_states=output_hidden_states, 196 | return_dict=return_dict, 197 | ) 198 | hidden_states = transformer_outputs[0] 199 | 200 | lm_logits = self.lm_head(hidden_states) 201 | 202 | loss = None 203 | if labels is not None: 204 | # Flatten the tokens 205 | loss = self.loss_function( 206 | lm_logits, 207 | labels, 208 | vocab_size=self.config.vocab_size, 209 | **kwargs, 210 | ) 211 | 212 | if not return_dict: 213 | output = (lm_logits,) + transformer_outputs[1:] 214 | return ((loss,) + output) if loss is not None else output 215 | 216 | return CausalLMOutputWithCrossAttentions( 217 | loss=loss, 218 | logits=lm_logits, 219 | past_key_values=transformer_outputs.past_key_values, 220 | hidden_states=transformer_outputs.hidden_states, 221 | attentions=transformer_outputs.attentions, 222 | cross_attentions=transformer_outputs.cross_attentions, 223 | ) 224 | 225 | class AttentionBase(nn.Module): 226 | """ 227 | Causal multihead attention that uses torch's SDPA 228 | """ 229 | def __init__(self, dim=512, heads=8): 230 | super().__init__() 231 | self.heads = heads 232 | self.to_qkv = nn.Linear(dim, dim*3, bias=False) 233 | self.to_out = nn.Linear(dim, dim, bias=True) 234 | 235 | def forward(self, x, 236 | layer_past: Optional[Tuple[torch.Tensor]] = None, 237 | attention_mask: Optional[torch.FloatTensor] = None, 238 | head_mask: Optional[torch.FloatTensor] = None, 239 | encoder_hidden_states: Optional[torch.Tensor] = None, 240 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 241 | use_cache: Optional[bool] = False, 242 | output_attentions: Optional[bool] = False, 243 | **kwargs,): 244 | 245 | b, n, d, h = (*x.shape, self.heads) 246 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (self.to_qkv(x).chunk(3, dim=-1))) 247 | outputs = (self.to_out(rearrange(sdpa(q, k, v, is_causal=True), 'b h n d -> b n (h d)')), None) 248 | return outputs 249 | 250 | 251 | def patch_gpt(model): 252 | # patch attentions 253 | for n, m in model.named_modules(): 254 | if hasattr(m, "attn"): 255 | dim = model.config.n_embd 256 | heads = m.attn.num_heads 257 | 258 | m.attn = AttentionBase(dim=dim, heads=heads) 259 | if hasattr(m, "mlp"): 260 | dim = model.config.n_embd 261 | inner_dim = model.config.n_inner if model.config.n_inner is not None else 4 * model.config.n_embd 262 | m.mlp.c_fc = nn.Linear(dim, inner_dim, bias=True) 263 | m.mlp.c_proj = nn.Linear(inner_dim, dim, bias=True) 264 | 265 | 266 | for block in model.transformer.h: 267 | block.forward = MethodType(gpt2_block_forward, block) 268 | model.transformer.forward = MethodType(gpt_model_forward, model.transformer) 269 | model.forward = MethodType(gpt_lm_head_model_forward, model) 270 | 271 | 272 | # init all 2d weights xavier normal 273 | for n, m in 
model.named_modules(): 274 | if isinstance(m, nn.Linear): 275 | if m.weight.dim() >= 2: 276 | nn.init.xavier_normal_(m.weight) 277 | 278 | # init all biases to 0, including norm params 279 | for n, m in model.named_modules(): 280 | if isinstance(m, nn.Linear): 281 | if m.bias is not None: 282 | nn.init.zeros_(m.bias) 283 | elif isinstance(m, nn.LayerNorm): 284 | nn.init.zeros_(m.bias) 285 | nn.init.ones_(m.weight) 286 | 287 | return model 288 | 289 | -------------------------------------------------------------------------------- /mlp_mods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/mlp_mods/__init__.py -------------------------------------------------------------------------------- /mlp_mods/gpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional, Tuple 4 | from common.activations import Activation 5 | from types import MethodType 6 | 7 | 8 | class Conv1D(nn.Module): 9 | """ 10 | 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). 11 | 12 | Basically works like a linear layer but the weights are transposed. 13 | 14 | Args: 15 | nf (`int`): The number of output features. 16 | nx (`int`): The number of input features. 17 | """ 18 | 19 | def __init__(self, nf, nx): 20 | super().__init__() 21 | self.nf = nf 22 | self.weight = nn.Parameter(torch.empty(nx, nf)) 23 | self.bias = nn.Parameter(torch.zeros(nf)) 24 | nn.init.normal_(self.weight, std=0.02) 25 | 26 | def forward(self, x): 27 | size_out = x.size()[:-1] + (self.nf,) 28 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 29 | x = x.view(size_out) 30 | return x 31 | 32 | 33 | class GeGLU(nn.Module): 34 | def __init__(self, in_dim=1024, out_dim=1024, activation_type='gelu', power=1.0, norm=False): 35 | super().__init__() 36 | self.fc = Conv1D(out_dim * 2, in_dim) 37 | self.act = Activation(activation_type=activation_type, power=power) 38 | self.norm = torch.nn.LayerNorm(out_dim) if norm else torch.nn.Identity() 39 | 40 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 41 | hidden_states, gate = self.norm(self.fc(hidden_states)).chunk(2, dim=-1) 42 | hidden_states = self.act(gate) * hidden_states 43 | return hidden_states 44 | 45 | 46 | class LinearAct(nn.Module): 47 | def __init__(self, in_dim=1024, out_dim=1024, activation_type='gelu', power=1.0, norm=False): 48 | super().__init__() 49 | self.fc = Conv1D(out_dim, in_dim) 50 | self.act = Activation(activation_type=activation_type, power=power) 51 | self.norm = torch.nn.LayerNorm(out_dim) if norm else torch.nn.Identity() 52 | 53 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 54 | hidden_states = self.act(self.norm(self.fc(hidden_states))) 55 | return hidden_states 56 | 57 | 58 | class GPT2MLP(nn.Module): 59 | """ 60 | MLP layer but allows for varying number of layers and hidden dimensions and inbetween norms if we want 61 | """ 62 | def __init__(self, 63 | embed_dim, 64 | mults=[2,2], 65 | norms=[False,False], 66 | mode = "base", # geglu 67 | activation_type='gelu', 68 | power=1.0 69 | ): 70 | super().__init__() 71 | net = [] 72 | cur_dim = embed_dim 73 | 74 | assert len(mults) == len(norms) 75 | 76 | for i in range(len(mults)): 77 | in_dim = embed_dim if i == 0 else cur_dim 78 | 
cur_dim = embed_dim * mults[i] 79 | if mode == "geglu": 80 | net.append(GeGLU(in_dim, cur_dim, activation_type, power, norms[i])) 81 | else: 82 | net.append(LinearAct(in_dim, cur_dim, activation_type, power, norms[i])) 83 | net.append(Conv1D(embed_dim, cur_dim)) 84 | self.net = nn.Sequential(*net) 85 | 86 | 87 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 88 | return self.net(hidden_states) 89 | 90 | 91 | 92 | class GPT2MLPMultipleResidual(nn.Module): 93 | 94 | """ 95 | multiple back to back MLPs with residual connections between each 96 | """ 97 | 98 | def __init__(self, 99 | embed_dim, 100 | mults=[[2,2], [1]], 101 | interior_norms=[[False,False],[False]], 102 | exterior_norms=[True, True], 103 | mode = ["base", "geglu"], # geglu 104 | activation_type='gelu', 105 | power=1.0 106 | ): 107 | super().__init__() 108 | self.sub_mlps = nn.ModuleList() 109 | self.norms = nn.ModuleList([ 110 | nn.LayerNorm(embed_dim) if norm else nn.Identity() for norm in exterior_norms 111 | ]) 112 | for i in range(len(mults)): 113 | self.sub_mlps.append(GPT2MLP(embed_dim, mults[i], interior_norms[i], mode[i], activation_type, power)) 114 | 115 | def forward(self, hidden_states) -> torch.FloatTensor: 116 | for mlp, norm in zip(self.sub_mlps, self.norms): 117 | hidden_states = hidden_states + mlp(norm(hidden_states)) 118 | return hidden_states 119 | 120 | 121 | 122 | 123 | def forward( 124 | self, 125 | hidden_states: Optional[Tuple[torch.FloatTensor]], 126 | layer_past: Optional[Tuple[torch.Tensor]] = None, 127 | attention_mask: Optional[torch.FloatTensor] = None, 128 | head_mask: Optional[torch.FloatTensor] = None, 129 | encoder_hidden_states: Optional[torch.Tensor] = None, 130 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 131 | use_cache: Optional[bool] = False, 132 | output_attentions: Optional[bool] = False, 133 | ): 134 | residual = hidden_states 135 | hidden_states = self.ln_1(hidden_states) 136 | attn_outputs = self.attn( 137 | hidden_states, 138 | layer_past=layer_past, 139 | attention_mask=attention_mask, 140 | head_mask=head_mask, 141 | use_cache=use_cache, 142 | output_attentions=output_attentions, 143 | ) 144 | attn_output = attn_outputs[0] # output_attn: a, present, (attentions) 145 | outputs = attn_outputs[1:] 146 | # residual connection 147 | hidden_states = attn_output + residual 148 | 149 | if encoder_hidden_states is not None: 150 | # add one self-attention block for cross-attention 151 | if not hasattr(self, "crossattention"): 152 | raise ValueError( 153 | f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " 154 | "cross-attention layers by setting `config.add_cross_attention=True`" 155 | ) 156 | residual = hidden_states 157 | hidden_states = self.ln_cross_attn(hidden_states) 158 | cross_attn_outputs = self.crossattention( 159 | hidden_states, 160 | attention_mask=attention_mask, 161 | head_mask=head_mask, 162 | encoder_hidden_states=encoder_hidden_states, 163 | encoder_attention_mask=encoder_attention_mask, 164 | output_attentions=output_attentions, 165 | ) 166 | attn_output = cross_attn_outputs[0] 167 | # residual connection 168 | hidden_states = residual + attn_output 169 | outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights 170 | 171 | # residual = hidden_states 172 | # hidden_states = self.ln_2(hidden_states) 173 | # feed_forward_hidden_states = self.mlp(hidden_states) 174 | # # residual connection 175 | # hidden_states = residual + 
feed_forward_hidden_states 176 | 177 | #### 178 | hidden_states = self.mlp(hidden_states) 179 | #### 180 | 181 | if use_cache: 182 | outputs = (hidden_states,) + outputs 183 | else: 184 | outputs = (hidden_states,) + outputs[1:] 185 | 186 | return outputs # hidden_states, present, (attentions, cross_attentions) 187 | 188 | 189 | 190 | def patch_model(model, args, exp_args): 191 | idx = 0 192 | for n,m in model.named_modules(): 193 | if hasattr(m, "mlp"): 194 | 195 | meets_criteria = False 196 | if exp_args["targets"] == "all": 197 | meets_criteria = True 198 | elif exp_args["targets"] == "even": 199 | meets_criteria = idx % 2 == 0 200 | elif exp_args["targets"] == "odd": 201 | meets_criteria = idx % 2 == 1 202 | elif exp_args["targets"] == "first_half": 203 | meets_criteria = idx < len(model.transformer.h) // 2 204 | elif exp_args["targets"] == "second_half": 205 | meets_criteria = idx >= len(model.transformer.h) // 2 206 | 207 | if meets_criteria: 208 | m.mlp = GPT2MLPMultipleResidual( 209 | embed_dim = model.config.hidden_size, 210 | mults = exp_args["mults"], 211 | interior_norms = exp_args["interior_norms"], 212 | exterior_norms = exp_args["exterior_norms"], 213 | mode = exp_args["mode"] 214 | ) 215 | m.forward = MethodType(forward, m) 216 | 217 | idx += 1 218 | return model 219 | 220 | 221 | def get_run_name(args, exp_args): 222 | run_name = "mult_" + str(exp_args["mults"]) + "_in_" + str(exp_args["interior_norms"]) + "_en_" + str(exp_args["exterior_norms"]) + "_m_" + str(exp_args["mode"]) + "_t_" + str(exp_args["targets"]) + "_lr:" + str(args["learning_rate"]) 223 | args["output_dir"] = f"{args['base_output_dir']}/{run_name}" 224 | 225 | return args, run_name 226 | 227 | 228 | extra_args = { 229 | "mults": [[4]], 230 | "interior_norms": [[False]], 231 | "exterior_norms": [True], 232 | "mode": ["base"], 233 | "targets": "all", # all, even, odd, first_half, second_half 234 | 235 | 236 | # "mults": [[4]], 237 | # "interior_norms": [[False]], 238 | # "exterior_norms": [True], 239 | # "mode": ["geglu"], 240 | # "targets": "all", # all, even, odd, first_half, second_half 241 | 242 | 243 | # "mults": [[2,2]], 244 | # "interior_norms": [[False, False]], 245 | # "exterior_norms": [True], 246 | # "mode": ["base"], 247 | # "targets": "all", # all, even, odd, first_half, second_half 248 | 249 | # "mults": [[2,2]], 250 | # "interior_norms": [[False, True]], 251 | # "exterior_norms": [True], 252 | # "mode": ["base"], 253 | # "targets": "all", # all, even, odd, first_half, second_half 254 | 255 | # "mults": [[1,1,1,1]], 256 | # "interior_norms": [[False, False, False, False]], 257 | # "exterior_norms": [True], 258 | # "mode": ["base"], 259 | # "targets": "all", # all, even, odd, first_half, second_half 260 | 261 | # "mults": [[1,1,1,1]], 262 | # "interior_norms": [[False, True, False, True]], 263 | # "exterior_norms": [True], 264 | # "mode": ["base"], 265 | # "targets": "all", # all, even, odd, first_half, second_half 266 | 267 | 268 | # "mults": [[1,1,1,1]], 269 | # "interior_norms": [[False, False, False, False]], 270 | # "exterior_norms": [True], 271 | # "mode": ["geglu"], 272 | # "targets": "all", # all, even, odd, first_half, second_half 273 | 274 | # "mults": [[1,1,1,1]], 275 | # "interior_norms": [[False, True, False, True]], 276 | # "exterior_norms": [True], 277 | # "mode": ["geglu"], 278 | # "targets": "all", # all, even, odd, first_half, second_half 279 | 280 | 281 | # "mults": [[1,1,1,1]], 282 | # "interior_norms": [[False, False, False, False]], 283 | # "exterior_norms": [True], 284 | # 
"mode": ["geglu"], 285 | # "targets": "all", # all, even, odd, first_half, second_half 286 | 287 | # "mults": [[1,1,1,1]], 288 | # "interior_norms": [[False, True, False, True]], 289 | # "exterior_norms": [True], 290 | # "mode": ["geglu"], 291 | # "targets": "all", # all, even, odd, first_half, second_half 292 | 293 | # "mults": [[2],[2],[2]], 294 | # "interior_norms": [[False],[False],[False]], 295 | # "exterior_norms": [True, True, True], 296 | # "mode": ["base", "base", "base"], 297 | # "targets": "all", # all, even, odd, first_half, second_half 298 | 299 | # "mults": [[1],[1],[1]], 300 | # "interior_norms": [[False],[False],[False]], 301 | # "exterior_norms": [True, True, True], 302 | # "mode": ["base", "base", "base"], 303 | # "targets": "all", # all, even, odd, first_half, second_half 304 | 305 | } 306 | -------------------------------------------------------------------------------- /mlp_mods/vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | 9 | from common.activations import Activation, LinearAct 10 | 11 | class FeedForward(nn.Module): 12 | def __init__(self, dim, hidden_dim, dropout = 0., act="gelu", act_power=1): 13 | super().__init__() 14 | self.net = nn.Sequential( 15 | nn.Linear(dim, hidden_dim), 16 | Activation(act, power=act_power, dim=hidden_dim), 17 | nn.Dropout(dropout), 18 | nn.Linear(hidden_dim, dim), 19 | nn.Dropout(dropout) 20 | ) 21 | def forward(self, x): 22 | return self.net(x) 23 | 24 | 25 | class GegluFeedForward(nn.Module): 26 | def __init__(self, dim, hidden_dim, dropout = 0., act="gelu", act_power=1): 27 | super().__init__() 28 | self.up = nn.Linear(dim, hidden_dim * 2) 29 | self.down = nn.Linear(hidden_dim, dim) 30 | self.act = Activation(act, power=act_power, dim=hidden_dim) 31 | self.dropout = nn.Dropout(dropout) 32 | 33 | def forward(self, x): 34 | x, gate = self.up(x).chunk(2, dim=-1) 35 | gate = self.act(gate) 36 | x = self.dropout(x) 37 | gate = self.dropout(gate) 38 | return self.dropout(self.down(x * gate)) 39 | 40 | class DoubleFeedForward(nn.Module): 41 | def __init__(self, dim, hidden_dim, dropout = 0., act="gelu", act_power=1): 42 | super().__init__() 43 | self.net = nn.Sequential( 44 | nn.Linear(dim, hidden_dim), 45 | Activation(act, power=act_power, dim=hidden_dim), 46 | nn.Dropout(dropout), 47 | nn.Linear(hidden_dim, hidden_dim), 48 | Activation(act, power=act_power, dim=hidden_dim), 49 | nn.Dropout(dropout), 50 | nn.Linear(hidden_dim, dim), 51 | nn.Dropout(dropout) 52 | ) 53 | 54 | def forward(self, x): 55 | return self.net(x) 56 | 57 | 58 | # def make_linear(in_dim, out_dim, bias=True, init_type='xavier', **init_kwargs): 59 | # linear = nn.Linear(in_dim, out_dim, bias=bias) 60 | # if init_type == 'xavier': 61 | # nn.init.xavier_uniform_(linear.weight, **init_kwargs) 62 | # elif init_type == 'kaiming': 63 | # nn.init.kaiming_uniform_(linear.weight, **init_kwargs) 64 | # if bias: 65 | # nn.init.constant_(linear.bias, 0) 66 | # return linear 67 | 68 | 69 | # class GeGLU(nn.Module): 70 | # def __init__(self, in_dim=1024, out_dim=1024, activation_type='gelu', power=1.0, norm=False): 71 | # super().__init__() 72 | # self.fc = make_linear(in_dim, out_dim * 2) 73 | # self.act = Activation(activation_type=activation_type, power=power) 74 | # self.norm = torch.nn.LayerNorm(out_dim) if norm else 
torch.nn.Identity() 75 | 76 | #     def forward(self, hidden_states) -> torch.FloatTensor: 77 | #         hidden_states, gate = self.norm(self.fc(hidden_states)).chunk(2, dim=-1) 78 | #         hidden_states = self.act(gate) * hidden_states 79 | #         return hidden_states 80 | 81 | 82 | # class LinearAct(nn.Module): 83 | #     def __init__(self, in_dim=1024, out_dim=1024, activation_type='gelu', power=1.0, norm=False): 84 | #         super().__init__() 85 | #         self.fc = make_linear(in_dim, out_dim) 86 | #         self.act = Activation(activation_type=activation_type, power=power) 87 | #         self.norm = torch.nn.LayerNorm(out_dim) if norm else torch.nn.Identity() 88 | 89 | #     def forward(self, hidden_states) -> torch.FloatTensor: 90 | #         hidden_states = self.act(self.norm(self.fc(hidden_states))) 91 | #         return hidden_states 92 | 93 | 94 | # class GPT2MLP(nn.Module): 95 | #     """ 96 | #     MLP layer but allows for varying number of layers and hidden dimensions and inbetween norms if we want 97 | #     """ 98 | #     def __init__(self, 99 | #                 embed_dim, 100 | #                 mults=[2,2], 101 | #                 norms=[False,False], 102 | #                 mode = "base", # geglu 103 | #                 activation_type='gelu', 104 | #                 power=1.0 105 | #                 ): 106 | #         super().__init__() 107 | #         net = [] 108 | #         cur_dim = embed_dim 109 | 110 | #         assert len(mults) == len(norms) 111 | 112 | #         for i in range(len(mults)): 113 | #             in_dim = embed_dim if i == 0 else cur_dim 114 | #             cur_dim = embed_dim * mults[i] 115 | #             if mode == "geglu": 116 | #                 net.append(GeGLU(in_dim, cur_dim, activation_type, power, norms[i])) 117 | #             else: 118 | #                 net.append(LinearAct(in_dim, cur_dim, activation_type, power, norms[i])) 119 | #         net.append(Conv1D(embed_dim, cur_dim)) 120 | #         self.net = nn.Sequential(*net) 121 | 122 | 123 | #     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 124 | #         return self.net(hidden_states) 125 | 126 | 127 | 128 | # class GPT2MLPMultipleResidual(nn.Module): 129 | 130 | #     """ 131 | #     multiple back to back MLPs with residual connections between each 132 | #     """ 133 | 134 | #     def __init__(self, 135 | #                 embed_dim, 136 | #                 mults=[[2,2], [1]], 137 | #                 interior_norms=[[False,False],[False]], 138 | #                 exterior_norms=[True, True], 139 | #                 mode = ["base", "geglu"], # geglu 140 | #                 activation_type='gelu', 141 | #                 power=1.0 142 | #                 ): 143 | #         super().__init__() 144 | #         self.sub_mlps = nn.ModuleList() 145 | #         self.norms = nn.ModuleList([ 146 | #             nn.LayerNorm(embed_dim) if norm else nn.Identity() for norm in exterior_norms 147 | #         ]) 148 | #         for i in range(len(mults)): 149 | #             self.sub_mlps.append(GPT2MLP(embed_dim, mults[i], interior_norms[i], mode[i], activation_type, power)) 150 | 151 | #     def forward(self, hidden_states) -> torch.FloatTensor: 152 | #         for mlp, norm in zip(self.sub_mlps, self.norms): 153 | #             hidden_states = hidden_states + mlp(norm(hidden_states)) 154 | #         return hidden_states 155 | 156 | 157 | 158 | 159 | def patch_model(model, args, exp_args): 160 |     for name, module in model.named_modules(): 161 |         if hasattr(module, 'ff'): 162 |             if exp_args['mode'] == "geglu": 163 |                 old_ff = module.ff  # keep a handle; deleting module.ff first would make its dims unreadable 164 |                 module.ff = GegluFeedForward(old_ff.net[0].in_features, 165 |                                             old_ff.net[0].in_features * exp_args["mult"], 166 |                                             old_ff.net[2].p, 167 |                                             # args.activation, 168 |                                             # args.act_power 169 |                                             ) 170 |             elif exp_args['mode'] == "double": 171 |                 module.ff = DoubleFeedForward(module.ff.net[0].in_features, 172 |                                             module.ff.net[0].in_features * exp_args["mult"], 173 |                                             module.ff.net[2].p, 174 |                                             # args.activation, 175 |                                             # args.act_power 176 |                                             ) 177 | 178 |     # init weights again 179 |     model.init_weights()
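    # Usage sketch (illustrative only; not executed here): with the ViT from common/vit.py,
    # this patch could be applied roughly as
    #   net = ViT(dim=512, depth=6, heads=8)
    #   net = patch_model(net, args=None, exp_args={"mode": "geglu", "mult": 2, "targets": "all"})
    # after which every Block.ff is a GegluFeedForward (or a DoubleFeedForward for
    # mode="double") and all weights have been re-initialized by the init_weights() call above.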
180 | 181 | return model 182 | 183 | 184 | # def get_run_name(args, exp_args): 185 | # run_name = "mult_" + str(exp_args["mults"]) + "_in_" + str(exp_args["interior_norms"]) + "_en_" + str(exp_args["exterior_norms"]) + "_m_" + str(args["mode"]) + "_t_" + str(args["targets"]) + "_lr:" + str(args["learning_rate"]) 186 | # args["output_dir"] = f"{args['base_output_dir']}/{run_name}" 187 | 188 | # return args, run_name 189 | 190 | 191 | def get_run_name(args, exp_args): 192 | run_name = exp_args["mode"] + "_mult_" + str(exp_args["mult"]) + "_t_" + str(exp_args["targets"]) + "_lr:" + str(args["lr"]) 193 | # args["output_dir"] = f"{args['base_output_dir']}/{run_name}" 194 | 195 | return args, run_name 196 | 197 | 198 | extra_args = { 199 | # "mults": [[2,2]], 200 | # "interior_norms": [[False,False]], 201 | # "exterior_norms": [True], 202 | # "mode": ["base"], 203 | # "targets": "all", # all, even, odd, first_half, second_half 204 | 205 | "targets": "all", # all, even, odd, first_half, second_half 206 | "mode": "double", # base, geglu, double 207 | "mult": 2, 208 | 209 | } -------------------------------------------------------------------------------- /multi_token_pred/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/multi_token_pred/__init__.py -------------------------------------------------------------------------------- /qknorm_half/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/qknorm_half/__init__.py -------------------------------------------------------------------------------- /qknorm_half/gpt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) 18 | on a text file or a dataset without using HuggingFace Trainer. 19 | 20 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 21 | https://huggingface.co/models?filter=text-generation 22 | """ 23 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 
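# Minimal illustrative sketch of what the NewGPT2Attention patch below computes (it is never
# called in this script): "knorm" unit-normalizes only the keys and keeps the usual 1/sqrt(d)
# query scaling, while "qknorm" normalizes both queries and keys and drops that scaling; in both
# cases a learned per-head temperature (initialized to 10) multiplies the normalized keys.
def _qk_norm_logits_sketch(q, k, softmax_temp, mode="knorm", eps=1e-6):
    # q, k: (batch, heads, seq, head_dim) tensors; softmax_temp: (1, heads, 1, 1) parameter
    k = k / k.norm(dim=-1, keepdim=True).clamp_min(eps)
    if mode == "qknorm":
        q = q / q.norm(dim=-1, keepdim=True).clamp_min(eps)
        scaling = 1.0
    else:  # "knorm"
        scaling = q.size(-1) ** -0.5
    return (q * scaling) @ (k * softmax_temp).transpose(-1, -2)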
24 | 25 | import argparse 26 | import json 27 | import logging 28 | import math 29 | import os 30 | import random 31 | from itertools import chain 32 | from pathlib import Path 33 | 34 | import datasets 35 | import torch 36 | from accelerate import Accelerator, DistributedType 37 | from accelerate.logging import get_logger 38 | from accelerate.utils import set_seed 39 | from datasets import load_dataset 40 | from huggingface_hub import Repository, create_repo 41 | from torch.utils.data import DataLoader 42 | from tqdm.auto import tqdm 43 | 44 | import transformers 45 | from transformers import ( 46 | CONFIG_MAPPING, 47 | MODEL_MAPPING, 48 | AutoConfig, 49 | AutoModelForCausalLM, 50 | AutoTokenizer, 51 | SchedulerType, 52 | default_data_collator, 53 | get_scheduler, 54 | ) 55 | from transformers.utils import check_min_version, send_example_telemetry 56 | from types import SimpleNamespace 57 | 58 | import torch 59 | import torch.nn as nn 60 | from typing import Optional, Tuple, Union 61 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention 62 | import time 63 | 64 | import pynvml 65 | 66 | 67 | def normalize(tensor): 68 | # eps = torch.finfo(tensor.dtype).eps 69 | eps = 1e-6 70 | norm = tensor.norm(dim=-1, keepdim=True) 71 | norm_clamped = torch.where(norm > eps, norm, eps) 72 | out = tensor / norm_clamped 73 | return out 74 | 75 | class NewGPT2Attention(GPT2Attention): 76 | def __init__(self, config, is_cross_attention=False, layer_idx=None, 77 | mode="knorm",# ["none", "qknorm", "knorm"] 78 | ): 79 | super().__init__(config, is_cross_attention=False, layer_idx=None) 80 | self.query_norm = normalize if mode == "qknorm" else nn.Identity() 81 | self.key_norm = normalize if (mode == "qknorm" or mode == "knorm") else nn.Identity() 82 | self.scaling = self.head_dim ** -0.5 83 | self.softmax_temp = None 84 | if mode == "knorm" or mode == "qknorm": 85 | self.softmax_temp = nn.Parameter(torch.ones(1, self.num_heads, 1, 1) * 10, requires_grad=True) 86 | if mode == "qknorm": 87 | self.scaling = 1.0 88 | print(mode) 89 | 90 | def _attn(self, query, key, value, attention_mask=None, head_mask=None): 91 | query = self.query_norm(query) 92 | key = self.key_norm(key) 93 | if self.softmax_temp is not None: 94 | key = key * self.softmax_temp 95 | query = query * self.scaling 96 | 97 | attn_weights = torch.matmul(query, key.transpose(-1, -2)) 98 | 99 | # Layer-wise attention scaling 100 | if self.scale_attn_by_inverse_layer_idx: 101 | attn_weights = attn_weights / float(self.layer_idx + 1) 102 | 103 | if not self.is_cross_attention: 104 | # if only "normal" attention layer implements causal mask 105 | query_length, key_length = query.size(-2), key.size(-2) 106 | causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] 107 | mask_value_min = torch.finfo(attn_weights.dtype).min 108 | # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. 
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` 110 |             mask_value_min = torch.full([], mask_value_min, dtype=attn_weights.dtype, device=attn_weights.device) 111 |             # mask_value_max = torch.full([], mask_value_max, dtype=attn_weights.dtype, device=attn_weights.device) 112 |             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value_min) 113 |             # attn_weights_neg = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value_max) 114 | 115 |         if attention_mask is not None: 116 |             # Apply the attention mask 117 |             attn_weights = attn_weights + attention_mask 118 | 119 |         attn_weights = nn.functional.softmax(attn_weights, dim=-1) 120 | 121 |         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise 122 |         attn_weights = self.attn_dropout(attn_weights.type(value.dtype)) 123 | 124 |         # Mask heads if we want to 125 |         if head_mask is not None: 126 |             attn_weights = attn_weights * head_mask 127 | 128 |         attn_output = torch.matmul(attn_weights, value) 129 | 130 |         return attn_output, attn_weights 131 | 132 | 133 | def patch_attn(model, mode="knorm"): 134 |     conf = model.config 135 |     idx = 0 136 | 137 |     for n,m in model.named_modules(): 138 |         if hasattr(m, "attn"): 139 |             # if idx in indices: 140 |             del m.attn 141 |             m.add_module("attn", NewGPT2Attention(conf, is_cross_attention=False, layer_idx=None, mode=mode)) 142 |             # print("activated", idx) 143 |             idx += 1 144 |     # print('current idx', idx) 145 | 146 | 147 | extra_args = { 148 |     "mode": "knorm", 149 | } 150 | 151 | 152 | def get_run_name(args): 153 |     base_str = "base" 154 |     if args["mode"] == "knorm" or args["mode"] == "qknorm": 155 |         base_str = args["mode"] 156 |     args["output_dir"] = f"{args['base_output_dir']}/{base_str}" 157 |     return args, base_str -------------------------------------------------------------------------------- /qknorm_half/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | '''Some helper functions for PyTorch, including: 4 |     - get_mean_and_std: calculate the mean and std value of dataset. 5 |     - init_params: net parameter initialization. 6 |     - progress_bar: progress bar mimics xlua.progress.
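    - format_time: human-readable durations (days/hours/minutes/seconds/ms) for the bar.
    - load_model / load_data: build the ViT (optionally QK-norm patched) and the CIFAR-10 loaders.
    - train / test: one epoch of mixed-precision training / evaluation with best-accuracy checkpointing.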
7 | ''' 8 | import os 9 | import sys 10 | import time 11 | import math 12 | 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | import torch 16 | import torchvision 17 | import torchvision.transforms as transforms 18 | from randomaug import RandAugment 19 | 20 | # import sys 21 | # sys.path.append("../.") 22 | # from common.vit import ViT 23 | import sys 24 | sys.path.append(".") 25 | from models.vit import ViT, patch 26 | 27 | def get_mean_and_std(dataset): 28 | '''Compute the mean and std value of dataset.''' 29 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 30 | mean = torch.zeros(3) 31 | std = torch.zeros(3) 32 | print('==> Computing mean and std..') 33 | for inputs, targets in dataloader: 34 | for i in range(3): 35 | mean[i] += inputs[:,i,:,:].mean() 36 | std[i] += inputs[:,i,:,:].std() 37 | mean.div_(len(dataset)) 38 | std.div_(len(dataset)) 39 | return mean, std 40 | 41 | def init_params(net): 42 | '''Init layer parameters.''' 43 | for m in net.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | init.kaiming_normal(m.weight, mode='fan_out') 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | elif isinstance(m, nn.BatchNorm2d): 49 | init.constant(m.weight, 1) 50 | init.constant(m.bias, 0) 51 | elif isinstance(m, nn.Linear): 52 | init.normal(m.weight, std=1e-3) 53 | if m.bias: 54 | init.constant(m.bias, 0) 55 | 56 | 57 | try: 58 | _, term_width = os.popen('stty size', 'r').read().split() 59 | except: 60 | term_width = 80 61 | term_width = int(term_width) 62 | 63 | TOTAL_BAR_LENGTH = 65. 64 | last_time = time.time() 65 | begin_time = last_time 66 | def progress_bar(current, total, msg=None): 67 | global last_time, begin_time 68 | if current == 0: 69 | begin_time = time.time() # Reset for new bar. 70 | 71 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 72 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 73 | 74 | sys.stdout.write(' [') 75 | for i in range(cur_len): 76 | sys.stdout.write('=') 77 | sys.stdout.write('>') 78 | for i in range(rest_len): 79 | sys.stdout.write('.') 80 | sys.stdout.write(']') 81 | 82 | cur_time = time.time() 83 | step_time = cur_time - last_time 84 | last_time = cur_time 85 | tot_time = cur_time - begin_time 86 | 87 | L = [] 88 | L.append(' Step: %s' % format_time(step_time)) 89 | L.append(' | Tot: %s' % format_time(tot_time)) 90 | if msg: 91 | L.append(' | ' + msg) 92 | 93 | msg = ''.join(L) 94 | sys.stdout.write(msg) 95 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 96 | sys.stdout.write(' ') 97 | 98 | # Go back to the center of the bar. 
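    # each '\b' moves the cursor one column left, so the " current/total " counter written
    # just below ends up printed in the middle of the bar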
99 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 100 | sys.stdout.write('\b') 101 | sys.stdout.write(' %d/%d ' % (current+1, total)) 102 | 103 | if current < total-1: 104 | sys.stdout.write('\r') 105 | else: 106 | sys.stdout.write('\n') 107 | sys.stdout.flush() 108 | 109 | def format_time(seconds): 110 | days = int(seconds / 3600/24) 111 | seconds = seconds - days*3600*24 112 | hours = int(seconds / 3600) 113 | seconds = seconds - hours*3600 114 | minutes = int(seconds / 60) 115 | seconds = seconds - minutes*60 116 | secondsf = int(seconds) 117 | seconds = seconds - secondsf 118 | millis = int(seconds*1000) 119 | 120 | f = '' 121 | i = 1 122 | if days > 0: 123 | f += str(days) + 'D' 124 | i += 1 125 | if hours > 0 and i <= 2: 126 | f += str(hours) + 'h' 127 | i += 1 128 | if minutes > 0 and i <= 2: 129 | f += str(minutes) + 'm' 130 | i += 1 131 | if secondsf > 0 and i <= 2: 132 | f += str(secondsf) + 's' 133 | i += 1 134 | if millis > 0 and i <= 2: 135 | f += str(millis) + 'ms' 136 | i += 1 137 | if f == '': 138 | f = '0ms' 139 | return f 140 | 141 | 142 | def load_model(args): 143 | print('==> Building model..') 144 | net = ViT( 145 | image_size = args.size, 146 | patch_size = args.patch, 147 | num_classes = 10, 148 | dim = int(args.dim), 149 | depth = 6, 150 | heads = 8, 151 | mlp_dim = args.mlp_dim, 152 | dropout = 0.1, 153 | emb_dropout = 0.1, 154 | ) 155 | 156 | if args.mode == "qknorm" or args.mode == "knorm": 157 | patch(net) 158 | 159 | net = net.to(args.device) 160 | if args.compile: 161 | net_forward = torch.compile(net.forward) 162 | else: 163 | net_forward = net.forward 164 | 165 | return net, net_forward 166 | 167 | 168 | def load_data(args): 169 | # Data 170 | print('==> Preparing data..') 171 | transform_train = transforms.Compose([ 172 | transforms.RandomCrop(32, padding=4), 173 | transforms.Resize(args.size), 174 | transforms.RandomHorizontalFlip(), 175 | transforms.ToTensor(), 176 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 177 | ]) 178 | 179 | transform_test = transforms.Compose([ 180 | transforms.Resize(args.size), 181 | transforms.ToTensor(), 182 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 183 | ]) 184 | 185 | # Add RandAugment with N, M(hyperparameter) 186 | if args.aug: 187 | transform_train.transforms.insert(0, RandAugment(2, 14)) 188 | 189 | # Prepare dataset 190 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 191 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.bs, shuffle=True, num_workers=8) 192 | 193 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) 194 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8) 195 | 196 | #classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 197 | 198 | return trainloader, testloader 199 | 200 | 201 | 202 | def train(args, epoch, net, net_forward, trainloader, optimizer, scaler, loss_fn=None, optimizer_callback=None): 203 | print('\nEpoch: %d' % epoch) 204 | net.train() 205 | train_loss = 0 206 | correct = 0 207 | total = 0 208 | mp_dtype = torch.float32 209 | if args.mp_dtype=="bf16": 210 | mp_dtype = torch.bfloat16 211 | elif args.mp_dtype=="fp16": 212 | mp_dtype = torch.float16 213 | for batch_idx, (inputs, targets) in enumerate(trainloader): 214 | inputs, targets = inputs.to(args.device), targets.to(args.device) 215 | # Train with 
amp 216 | with torch.cuda.amp.autocast(enabled=args.mp_dtype!="fp32", dtype=mp_dtype): 217 | loss, preds = loss_fn(net_forward, inputs, targets) 218 | scaler.scale(loss).backward() 219 | scaler.step(optimizer) 220 | scaler.update() 221 | optimizer.zero_grad(set_to_none=True) 222 | if optimizer_callback is not None: 223 | optimizer_callback(optimizer) 224 | 225 | train_loss += loss.item() 226 | _, predicted = preds.max(1) 227 | total += targets.size(0) 228 | correct += predicted.eq(targets).sum().item() 229 | 230 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 231 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 232 | return train_loss/(batch_idx+1) 233 | 234 | ##### Validation 235 | def test(args, epoch, net, net_forward, testloader, optimizer, scaler): 236 | net.eval() 237 | test_loss = 0 238 | correct = 0 239 | total = 0 240 | with torch.no_grad(): 241 | for batch_idx, (inputs, targets) in enumerate(testloader): 242 | inputs, targets = inputs.to(args.device), targets.to(args.device) 243 | pred_labels = net_forward(inputs) 244 | loss = nn.CrossEntropyLoss()(pred_labels, targets) 245 | 246 | test_loss += loss.item() 247 | _, predicted = pred_labels.max(1) 248 | total += targets.size(0) 249 | correct += predicted.eq(targets).sum().item() 250 | 251 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 252 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 253 | 254 | # Save checkpoint. 255 | acc = 100.*correct/total 256 | if acc > args.best_acc: 257 | print('Saving..') 258 | state = {"model": net.state_dict(), 259 | "optimizer": optimizer.state_dict(), 260 | "scaler": scaler.state_dict()} 261 | if not os.path.isdir('checkpoint'): 262 | os.mkdir('checkpoint') 263 | torch.save(state, './checkpoint/'+args.net+'-{}-ckpt.t7'.format(args.patch)) 264 | args.best_acc = acc 265 | 266 | os.makedirs("log", exist_ok=True) 267 | content = time.ctime() + ' ' + f'Epoch {epoch}, lr: {optimizer.param_groups[0]["lr"]:.7f}, val loss: {test_loss:.5f}, acc: {(acc):.5f}' 268 | print(content) 269 | with open(f'log/log_{args.net}_patch{args.patch}.txt', 'a') as appender: 270 | appender.write(content + "\n") 271 | return test_loss, acc -------------------------------------------------------------------------------- /qknorm_half/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from einops import rearrange 4 | from torch.nn.functional import scaled_dot_product_attention as sdpa 5 | 6 | ######## 7 | 8 | def normalize(tensor): 9 | eps = 1e-6 if tensor.dtype == torch.float16 else 1e-10 10 | norm = tensor.norm(dim=-1, keepdim=True) 11 | norm_clamped = torch.where(norm > eps, norm, eps) 12 | out = tensor / norm_clamped 13 | return out 14 | 15 | 16 | class AttnNormBase(nn.Module): 17 | 18 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.): 19 | super().__init__() 20 | inner_dim = dim_head * heads 21 | self.heads = heads 22 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 23 | self.to_out = nn.Sequential( 24 | nn.Linear(inner_dim, dim), 25 | nn.Dropout(dropout) 26 | ) 27 | self.softmax_temp = torch.nn.Parameter(torch.ones(1, heads, 1, 1) * 10) 28 | 29 | def forward(self, x): 30 | raise NotImplementedError 31 | 32 | class KNormAttention(AttnNormBase): 33 | 34 | def forward(self, x): 35 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (self.to_qkv(x).chunk(3, dim = -1))) 36 | k = normalize(k) * 
self.softmax_temp 37 |         out = sdpa(q, k, v) 38 |         out = rearrange(out, 'b h n d -> b n (h d)') 39 |         return self.to_out(out) 40 | 41 | 42 | class QNormAttention(AttnNormBase): 43 | 44 |     def forward(self, x): 45 |         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (self.to_qkv(x).chunk(3, dim = -1))) 46 |         q = normalize(q) * self.softmax_temp 47 |         out = sdpa(q, k, v) 48 |         out = rearrange(out, 'b h n d -> b n (h d)') 49 |         return self.to_out(out) 50 | 51 | 52 | class QKNormAttention(AttnNormBase): 53 | 54 |     def forward(self, x): 55 |         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (self.to_qkv(x).chunk(3, dim = -1))) 56 |         q = normalize(q) 57 |         k = normalize(k) * self.softmax_temp 58 |         out = sdpa(q, k, v, scale=1.0) 59 |         out = rearrange(out, 'b h n d -> b n (h d)') 60 |         return self.to_out(out) 61 | 62 | 63 | def patch(vit, mode="knorm"): 64 |     for i in range(len(vit.transformer.attns)): 65 |         lyr = vit.transformer.attns[i] 66 |         if mode == "knorm": 67 |             vit.transformer.attns[i] = KNormAttention(dim=lyr.to_qkv.in_features, heads=lyr.heads, dropout=lyr.to_out[1].p) 68 |         elif mode == "qnorm": 69 |             vit.transformer.attns[i] = QNormAttention(dim=lyr.to_qkv.in_features, heads=lyr.heads, dropout=lyr.to_out[1].p) 70 |         elif mode == "qknorm": 71 |             vit.transformer.attns[i] = QKNormAttention(dim=lyr.to_qkv.in_features, heads=lyr.heads, dropout=lyr.to_out[1].p) 72 |         else: 73 |             raise ValueError(f"Unknown mode: {mode}") 74 | 75 | 76 | extra_args = { 77 |     "mode": "knorm" 78 | } 79 | 80 | def get_run_name(args, watermark): 81 |     if args.mode == "knorm": 82 |         watermark = "k_" + watermark 83 |     elif args.mode == "qknorm": 84 |         watermark = "qk_" + watermark 85 |     elif args.mode == "qnorm": 86 |         watermark = "q_" + watermark 87 |     return watermark -------------------------------------------------------------------------------- /relative_optimizers/caution_muon_adam.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | import torch 3 | 4 | # https://github.com/KellerJordan/Muon/blob/master/muon.py 5 | 6 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 7 |     """ 8 |     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 9 |     quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 10 |     of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 11 |     zero even beyond the point where the iteration no longer converges all the way to one everywhere 12 |     on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 13 |     where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 14 |     performance at all relative to UV^T, where USV^T = G is the SVD.
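    Concretely, each step forms A = X X^T and updates X <- a*X + (b*A + c*A^2) X with
    (a, b, c) = (3.4445, -4.7750, 2.0315), after casting G to bfloat16 and dividing it by its
    norm so that the top singular value is at most 1.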
15 | """ 16 | assert len(G.shape) == 2 17 | a, b, c = (3.4445, -4.7750, 2.0315) 18 | X = G.bfloat16() 19 | X /= X.norm() + eps # ensure top singular value <= 1 20 | if G.size(0) > G.size(1): 21 | X = X.T 22 | for _ in range(steps): 23 | A = X @ X.T 24 | B = b * A + c * A @ A 25 | X = a * X + B @ X 26 | if G.size(0) > G.size(1): 27 | X = X.T 28 | return X 29 | 30 | class CautionAdamMuon(torch.optim.Optimizer): 31 | 32 | def __init__( 33 | self, 34 | params, 35 | lr=0.02, 36 | beta1=0.95, 37 | beta2=0.999, 38 | eps=1e-8, 39 | weight_decay=0.01, 40 | ns_steps=6, 41 | nesterov=False, 42 | update_type="adam", 43 | caution_mode="caution", 44 | ): 45 | defaults = dict( 46 | lr=lr, 47 | beta1=beta1, 48 | beta2=beta2, 49 | eps=eps, 50 | weight_decay=weight_decay, 51 | ns_steps=ns_steps, 52 | nesterov=nesterov, 53 | update_type=update_type, 54 | caution_mode=caution_mode, 55 | ) 56 | 57 | super().__init__(params, defaults) 58 | 59 | @torch.no_grad() 60 | def step(self, closure=None): 61 | """Perform a single optimization step. 62 | 63 | Args: 64 | closure (Callable, optional): A closure that reevaluates the model 65 | and returns the loss. 66 | """ 67 | 68 | loss = None 69 | if closure is not None: 70 | with torch.enable_grad(): 71 | loss = closure() 72 | 73 | for group_num, group in enumerate(self.param_groups): 74 | for i, param in enumerate(group["params"]): 75 | grad = param.grad 76 | if grad is None: 77 | continue 78 | 79 | state = self.state[param] 80 | 81 | og_shape = grad.shape 82 | if grad.ndim != 2: 83 | grad = grad.view(grad.size(0), -1) 84 | 85 | if "exp_avg" not in state: 86 | state["exp_avg"] = torch.zeros_like(grad) 87 | state["exp_avg_sq"] = torch.zeros_like(grad) 88 | state["step"] = 0 89 | 90 | # do Adam update 91 | state["step"] += 1 92 | 93 | bias_correction1 = 1 - group["beta1"]**state["step"] 94 | bias_correction2 = 1 - group["beta2"]**state["step"] 95 | scale = bias_correction1 / bias_correction2**0.5 96 | 97 | # first and second moment update 98 | state["exp_avg"].lerp_(grad, 1 - group["beta1"]) 99 | state["exp_avg_sq"].lerp_(grad.pow(2), 1 - group["beta2"]) 100 | 101 | muon_update = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov"] else state["exp_avg"] 102 | 103 | # orthogonalization 104 | muon_update = zeropower_via_newtonschulz5(muon_update, steps=group["ns_steps"]) 105 | 106 | # rescaling 107 | muon_update *= max(1, muon_update.size(0)/muon_update.size(1))**0.5 108 | # muon_update = muon_update.view(og_shape).type_as(param.data) 109 | 110 | # adam update 111 | denom = state["exp_avg_sq"].sqrt().add_(group["eps"]) 112 | adam_update = state["exp_avg"].div(denom) 113 | 114 | # weight decay 115 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 116 | 117 | # caution mask 118 | if group["caution_mode"] == "caution": 119 | mask = (adam_update * muon_update > 0).to(muon_update.dtype).squeeze() 120 | update = muon_update if group["update_type"] == "muon" else adam_update 121 | param.data.add_(update.squeeze() * mask/(mask.mean()+group["eps"]), alpha=-group["lr"]) 122 | elif group["caution_mode"] == "scaling": 123 | # Calculate cosine similarity using torch's function (faster) 124 | cosine_sim = torch.nn.functional.cosine_similarity(adam_update.flatten().unsqueeze(0), muon_update.flatten().unsqueeze(0)).item() 125 | 126 | # Scale factor based on cosine similarity (higher similarity = higher scale) 127 | # When cosine_sim is 1, scale is 1; when cosine_sim is -1, scale is close to 0 128 | scale_factor = (cosine_sim + 1) # Map from [-1,1] to [0,2] 129 | 130 | 
# Apply the scaling to the update 131 | update = muon_update if group["update_type"] == "muon" else adam_update 132 | param.data.add_(update * scale_factor, alpha=-group["lr"]) 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /relative_optimizers/gpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .relative_adam import RelativeAdam 3 | from .muon import Muon 4 | from .relative_muon import RelativeMuon 5 | from .relative_muon_2 import RelativeMuon2 6 | from torch import nn 7 | from typing import Optional, Tuple 8 | from einops import rearrange 9 | from torch.nn import functional as F 10 | from torch.nn.functional import scaled_dot_product_attention as sdpa 11 | from hf_gpt_blocks.hf_gpt import patch_gpt 12 | from .caution_muon_adam import CautionAdamMuon 13 | 14 | 15 | def patch_optimizer(model, args, exp_args): 16 | no_decay = ["bias", "layer_norm.weight"] 17 | optimizer_grouped_parameters = [ 18 | { 19 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 20 | "weight_decay": args.weight_decay, 21 | }, 22 | { 23 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 24 | "weight_decay": 0.0, 25 | }, 26 | ] 27 | 28 | lr = exp_args.get("lr", args.learning_rate) 29 | weight_decay = exp_args.get("weight_decay", args.weight_decay) 30 | beta1 = exp_args.get("beta1", args.beta1) 31 | beta2 = exp_args.get("beta2", args.beta2) 32 | eps = exp_args.get("eps", args.eps) 33 | 34 | if exp_args["mode"] == "relative_adam": 35 | optimizer = RelativeAdam(optimizer_grouped_parameters, weight_decay=weight_decay, beta1=beta1, beta2=beta2, lr=lr, lr_weight=exp_args["lr_weight"], param_lr=exp_args["param_lr"], param_eps=exp_args["param_eps"], eps=eps) 36 | elif exp_args["mode"] == "adam": 37 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, betas=(beta1, beta2), eps=eps, fused=not args.compile_optimizer, weight_decay=weight_decay) 38 | elif exp_args["mode"] == "muon": 39 | optimizer = Muon(optimizer_grouped_parameters, lr=lr, beta1=beta1, eps=eps, weight_decay=weight_decay) 40 | elif exp_args["mode"] == "relative_muon": 41 | optimizer = RelativeMuon(optimizer_grouped_parameters, lr=lr, beta1=beta1, eps=eps, weight_decay=weight_decay, param_lr=exp_args["param_lr"], param_eps=exp_args["param_eps"], lr_weight=exp_args["lr_weight"], lr_cap=exp_args["lr_cap"]) 42 | elif exp_args["mode"] == "relative_muon_2": 43 | optimizer = RelativeMuon2(optimizer_grouped_parameters, lr=lr, beta1=beta1, eps=eps, weight_decay=weight_decay, param_lr=exp_args["param_lr"], lr_weight=exp_args["lr_weight"]) 44 | elif exp_args["mode"] == "caution_muon_adam": 45 | optimizer = CautionAdamMuon(optimizer_grouped_parameters, lr=lr, beta1=beta1, beta2=beta2, eps=eps, weight_decay=weight_decay, update_type=exp_args["update_type"], caution_mode=exp_args["caution_mode"]) 46 | else: 47 | raise ValueError(f"Invalid optimizer: {exp_args['mode']}") 48 | 49 | return optimizer 50 | 51 | 52 | def patch_model(model, args, exp_args): 53 | patch_gpt(model) 54 | return model 55 | 56 | 57 | def get_run_name(args, exp_args): 58 | run_name = exp_args["mode"] 59 | 60 | # if exp_args["mode"] == "base": 61 | # run_name = "base" 62 | # elif exp_args["mode"] == "relative_adam": 63 | # run_name = "mode_" + exp_args["mode"] + "_p_eps_" + str(exp_args["param_eps"]) + "_p_lr_" + str(exp_args["param_lr"]) + "_lr_w_" + str(exp_args["lr_weight"]) + "_lr_" + 
str(args["learning_rate"]) 64 | # elif exp_args["mode"] == "relative_adam_2": 65 | # run_name = "mode_" + exp_args["mode"] + "_p_eps_" + str(exp_args["param_eps"]) + "_lr_" + str(args["learning_rate"]) 66 | # else: 67 | # raise ValueError(f"Invalid optimizer: {exp_args['mode']}") 68 | args["output_dir"] = f"{args['base_output_dir']}/{run_name}" 69 | 70 | return args, run_name 71 | 72 | 73 | extra_args = { 74 | # "mode": "adam", # ["base", "relative_adam", "relative_adam_2"] 75 | # "lr": 5.0e-5, 76 | # "weight_decay": 0.01, 77 | # "beta1": 0.9, 78 | # "beta2": 0.99, 79 | 80 | # "mode": "relative_adam", 81 | # "lr": 5.0e-5, 82 | # "weight_decay": 0.1, 83 | # "beta1": 0.9, 84 | # "beta2": 0.98, 85 | # "param_eps": 1e-4, 86 | # "lr_weight": 0.5, 87 | # "param_lr": 0.008, 88 | # "lr_cap": 0.02, 89 | 90 | # "mode": "muon", 91 | # "lr": 2.0e-3, 92 | # "weight_decay": 0.01, 93 | # "beta1": 0.95, 94 | 95 | 96 | # "mode": "relative_muon", 97 | # "lr": 2.0e-3, 98 | # "weight_decay": 0.1, 99 | # "beta1": 0.95, 100 | # "param_lr": 0.1, 101 | # "param_eps": 1e-4, 102 | # "lr_weight": 0.5, 103 | # "lr_cap": 0.5, 104 | 105 | # "mode": "relative_muon_2", 106 | # "lr": 2.0e-3, 107 | # "weight_decay": 0.1, 108 | # "beta1": 0.95, 109 | # "param_lr": 0.1, 110 | # "lr_weight": 0.5, 111 | 112 | "mode": "caution_muon_adam", 113 | "lr": 2.0e-3, 114 | "weight_decay": 0.1, 115 | "beta1": 0.95, 116 | "beta2": 0.999, 117 | "caution_mode": "caution", 118 | "update_type": "muon", 119 | } 120 | -------------------------------------------------------------------------------- /relative_optimizers/muon.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | import torch 3 | 4 | # https://github.com/KellerJordan/Muon/blob/master/muon.py 5 | 6 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 7 | """ 8 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 9 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 10 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 11 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 12 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 13 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 14 | performance at all relative to UV^T, where USV^T = G is the SVD. 15 | """ 16 | assert len(G.shape) == 2 17 | a, b, c = (3.4445, -4.7750, 2.0315) 18 | X = G.bfloat16() 19 | X /= X.norm() + eps # ensure top singular value <= 1 20 | if G.size(0) > G.size(1): 21 | X = X.T 22 | for _ in range(steps): 23 | A = X @ X.T 24 | B = b * A + c * A @ A 25 | X = a * X + B @ X 26 | if G.size(0) > G.size(1): 27 | X = X.T 28 | return X 29 | 30 | 31 | class Muon(torch.optim.Optimizer): 32 | 33 | def __init__( 34 | self, 35 | params, 36 | lr=0.02, 37 | beta1=0.95, 38 | eps=1e-8, 39 | weight_decay=0.01, 40 | ns_steps=6, 41 | exp_avg_momentum=True, 42 | nesterov=False, 43 | ): 44 | defaults = dict( 45 | lr=lr, 46 | beta1=beta1, 47 | eps=eps, 48 | weight_decay=weight_decay, 49 | ns_steps=ns_steps, 50 | exp_avg_momentum=exp_avg_momentum, 51 | nesterov=nesterov 52 | ) 53 | 54 | super().__init__(params, defaults) 55 | 56 | @torch.no_grad() 57 | def step(self, closure=None): 58 | """Perform a single optimization step. 
59 | 60 | Args: 61 | closure (Callable, optional): A closure that reevaluates the model 62 | and returns the loss. 63 | """ 64 | 65 | loss = None 66 | if closure is not None: 67 | with torch.enable_grad(): 68 | loss = closure() 69 | 70 | for group_num, group in enumerate(self.param_groups): 71 | for i, param in enumerate(group["params"]): 72 | grad = param.grad 73 | if grad is None: 74 | continue 75 | 76 | og_shape = grad.shape 77 | if grad.ndim != 2: 78 | grad = grad.view(grad.size(0), -1) 79 | 80 | # do Muon update 81 | state = self.state[param] 82 | if "exp_avg" not in state: 83 | state["exp_avg"] = torch.zeros_like(grad) 84 | state["step"] = 0 85 | 86 | state["step"] += 1 87 | 88 | # momentum update 89 | if group['exp_avg_momentum']: 90 | state["exp_avg"].lerp_(grad, 1 - group["beta1"]) 91 | else: 92 | state["exp_avg"].mul_(group["beta1"]).add_(grad) 93 | 94 | update = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov"] else state["exp_avg"] 95 | 96 | # orthogonalization 97 | g = zeropower_via_newtonschulz5(update, steps=group["ns_steps"]) 98 | 99 | # rescaling 100 | g *= max(1, g.size(0)/g.size(1))**0.5 101 | g = g.view(og_shape).type_as(param.data) 102 | 103 | # update and weight decay 104 | param.data.add_(g, alpha=-group["lr"]) 105 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /relative_optimizers/relative_adam.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | 5 | 6 | class RelativeAdam(torch.optim.Optimizer): 7 | def __init__( 8 | self, 9 | params, 10 | lr=1e-4, 11 | weight_decay=0.01, 12 | beta1=0.9, 13 | beta2=0.999, 14 | eps=1e-8, 15 | lr_weight=0.5, 16 | param_lr=0.0075, 17 | param_eps=1e-4, 18 | lr_cap=0.02 19 | ): 20 | defaults = dict( 21 | lr=lr, 22 | orig_lr=lr, 23 | weight_decay=weight_decay, 24 | beta1=beta1, 25 | beta2=beta2, 26 | eps=eps, 27 | lr_weight=lr_weight, 28 | param_lr=param_lr, 29 | param_eps=param_eps, 30 | lr_cap=lr_cap 31 | ) 32 | super().__init__(params, defaults) 33 | 34 | def step(self, closure=None): 35 | for group in self.param_groups: 36 | for p in group['params']: 37 | g = p.grad 38 | if g is None: 39 | continue 40 | state = self.state[p] 41 | if 'step' not in state: 42 | state['step'] = 0 43 | state['exp_avg'] = torch.zeros_like(g) 44 | state['exp_avg_sq'] = torch.zeros_like(g) 45 | 46 | state['step'] += 1 47 | 48 | # update momentum and variance 49 | state['exp_avg'].lerp_(g, 1 - group['beta1']) 50 | state['exp_avg_sq'].lerp_(g.square(), 1 - group['beta2']) 51 | 52 | # the update 53 | g = state['exp_avg'] / (group['eps'] + state['exp_avg_sq'].sqrt()) 54 | 55 | # bias correction 56 | bias_correction1 = 1 - group['beta1'] ** state['step'] 57 | bias_correction2 = 1 - group['beta2'] ** state['step'] 58 | scale = bias_correction1 / bias_correction2**0.5 59 | 60 | # apply weight decay and update 61 | p.data.mul_(1 - group['lr'] * group['weight_decay']) 62 | 63 | # regular update 64 | p.data.add_(g, alpha=-(group['lr'] * group['lr_weight']) / scale) 65 | 66 | # parameter-level learning rate 67 | 68 | # to handle lr scheduling 69 | ratio = group['lr'] / group['orig_lr'] 70 | param_lr = group['param_lr'] * ratio 71 | 72 | p.data.add_( 73 | torch.clamp( 74 | g * (p.abs() + group['param_eps']), 75 | max=group['lr_cap'], 76 | min=-group['lr_cap'] 77 | ), 78 | alpha=-(param_lr * (1-group['lr_weight'])) / scale 79 | ) 80 | 
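

# Minimal usage sketch: a toy model and random data (both placeholders, not part of the
# training scripts in this repo) just to show how RelativeAdam is constructed and stepped.
if __name__ == "__main__":
    from torch import nn

    model = nn.Linear(64, 64)
    opt = RelativeAdam(model.parameters(), lr=1e-4, lr_weight=0.5,
                       param_lr=0.0075, param_eps=1e-4, lr_cap=0.02)
    for _ in range(3):
        loss = model(torch.randn(8, 64)).pow(2).mean()
        loss.backward()
        opt.step()       # a plain Adam step blended with a |param|-scaled, clamped step
        opt.zero_grad()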
-------------------------------------------------------------------------------- /relative_optimizers/relative_muon.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | import torch 3 | 4 | # https://github.com/KellerJordan/Muon/blob/master/muon.py 5 | 6 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 7 | """ 8 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 9 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 10 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 11 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 12 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 13 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 14 | performance at all relative to UV^T, where USV^T = G is the SVD. 15 | """ 16 | assert len(G.shape) == 2 17 | a, b, c = (3.4445, -4.7750, 2.0315) 18 | X = G.bfloat16() 19 | X /= X.norm() + eps # ensure top singular value <= 1 20 | if G.size(0) > G.size(1): 21 | X = X.T 22 | for _ in range(steps): 23 | A = X @ X.T 24 | B = b * A + c * A @ A 25 | X = a * X + B @ X 26 | if G.size(0) > G.size(1): 27 | X = X.T 28 | return X 29 | 30 | 31 | class RelativeMuon(torch.optim.Optimizer): 32 | 33 | def __init__( 34 | self, 35 | params, 36 | lr=0.02, 37 | beta1=0.95, 38 | eps=1e-8, 39 | weight_decay=0.01, 40 | ns_steps=6, 41 | exp_avg_momentum=True, 42 | nesterov=False, 43 | param_lr=0.005, 44 | param_eps=1e-4, 45 | lr_weight=0.5, 46 | lr_cap=0.01, 47 | ): 48 | defaults = dict( 49 | lr=lr, 50 | beta1=beta1, 51 | eps=eps, 52 | weight_decay=weight_decay, 53 | ns_steps=ns_steps, 54 | exp_avg_momentum=exp_avg_momentum, 55 | nesterov=nesterov, 56 | param_lr=param_lr, 57 | param_eps=param_eps, 58 | lr_weight=lr_weight, 59 | lr_cap=lr_cap, 60 | ) 61 | 62 | super().__init__(params, defaults) 63 | 64 | @torch.no_grad() 65 | def step(self, closure=None): 66 | """Perform a single optimization step. 67 | 68 | Args: 69 | closure (Callable, optional): A closure that reevaluates the model 70 | and returns the loss. 
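
        In rough pseudo-code, after the decoupled weight-decay shrink each 2D parameter W gets

            g  = zeropower_via_newtonschulz5(momentum) * sqrt(max(1, rows / cols))
            W -= lr * lr_weight * g
            W += clip(-param_lr * (1 - lr_weight) * g * (|W| + param_eps), -lr_cap, lr_cap)

        so part of the step is a plain Muon update and part is scaled by each weight's own magnitude.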
71 | """ 72 | 73 | loss = None 74 | if closure is not None: 75 | with torch.enable_grad(): 76 | loss = closure() 77 | 78 | for group_num, group in enumerate(self.param_groups): 79 | for i, param in enumerate(group["params"]): 80 | grad = param.grad 81 | if grad is None: 82 | continue 83 | 84 | # do Muon update 85 | og_shape = grad.shape 86 | if grad.ndim != 2: 87 | grad = grad.view(grad.size(0), -1) 88 | 89 | state = self.state[param] 90 | if "exp_avg" not in state: 91 | state["exp_avg"] = torch.zeros_like(grad) 92 | state["step"] = 0 93 | 94 | state["step"] += 1 95 | 96 | # momentum update 97 | if group['exp_avg_momentum']: 98 | state["exp_avg"].lerp_(grad, 1 - group["beta1"]) 99 | else: 100 | state["exp_avg"].mul_(group["beta1"]).add_(grad) 101 | 102 | update = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov"] else state["exp_avg"] 103 | 104 | # orthogonalization 105 | g = zeropower_via_newtonschulz5(update, steps=group["ns_steps"]) 106 | 107 | # rescaling 108 | g *= max(1, g.size(0)/g.size(1))**0.5 109 | g = g.view(og_shape).type_as(param.data) 110 | 111 | # weight decay 112 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 113 | 114 | # regular update 115 | param.data.add_(g, alpha=-(group['lr'] * group['lr_weight'])) 116 | 117 | # parameter-level learning rate 118 | update = (g * (param.abs() + group['param_eps'])) * (-(group['param_lr'] * (1-group['lr_weight']))) 119 | param.data.add_(torch.clamp(update, max=group['lr_cap'], min=-group['lr_cap'])) 120 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /relative_optimizers/relative_muon_2.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | import torch 3 | 4 | # https://github.com/KellerJordan/Muon/blob/master/muon.py 5 | 6 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 7 | """ 8 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 9 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 10 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 11 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 12 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 13 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 14 | performance at all relative to UV^T, where USV^T = G is the SVD. 
15 | """ 16 | assert len(G.shape) == 2 17 | a, b, c = (3.4445, -4.7750, 2.0315) 18 | X = G.bfloat16() 19 | X /= X.norm() + eps # ensure top singular value <= 1 20 | if G.size(0) > G.size(1): 21 | X = X.T 22 | for _ in range(steps): 23 | A = X @ X.T 24 | B = b * A + c * A @ A 25 | X = a * X + B @ X 26 | if G.size(0) > G.size(1): 27 | X = X.T 28 | return X 29 | 30 | 31 | class RelativeMuon2(torch.optim.Optimizer): 32 | 33 | def __init__( 34 | self, 35 | params, 36 | lr=0.02, 37 | beta1=0.95, 38 | eps=1e-8, 39 | weight_decay=0.01, 40 | ns_steps=6, 41 | exp_avg_momentum=True, 42 | nesterov=False, 43 | param_lr=0.005, 44 | lr_weight=0.5, 45 | ): 46 | defaults = dict( 47 | lr=lr, 48 | beta1=beta1, 49 | eps=eps, 50 | weight_decay=weight_decay, 51 | ns_steps=ns_steps, 52 | exp_avg_momentum=exp_avg_momentum, 53 | nesterov=nesterov, 54 | param_lr=param_lr, 55 | lr_weight=lr_weight, 56 | ) 57 | 58 | super().__init__(params, defaults) 59 | 60 | @torch.no_grad() 61 | def step(self, closure=None): 62 | """Perform a single optimization step. 63 | 64 | Args: 65 | closure (Callable, optional): A closure that reevaluates the model 66 | and returns the loss. 67 | """ 68 | 69 | loss = None 70 | if closure is not None: 71 | with torch.enable_grad(): 72 | loss = closure() 73 | 74 | for group_num, group in enumerate(self.param_groups): 75 | for i, param in enumerate(group["params"]): 76 | grad = param.grad 77 | if grad is None: 78 | continue 79 | 80 | # do Muon update 81 | og_shape = grad.shape 82 | if grad.ndim != 2: 83 | grad = grad.view(grad.size(0), -1) 84 | 85 | state = self.state[param] 86 | if "exp_avg" not in state: 87 | state["exp_avg"] = torch.zeros_like(grad) 88 | state["step"] = 0 89 | 90 | state["step"] += 1 91 | 92 | # momentum update 93 | if group['exp_avg_momentum']: 94 | state["exp_avg"].lerp_(grad, 1 - group["beta1"]) 95 | else: 96 | state["exp_avg"].mul_(group["beta1"]).add_(grad) 97 | 98 | update = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov"] else state["exp_avg"] 99 | 100 | # orthogonalization 101 | g = zeropower_via_newtonschulz5(update, steps=group["ns_steps"]) 102 | 103 | # rescaling 104 | g *= max(1, g.size(0)/g.size(1))**0.5 105 | g = g.view(og_shape).type_as(param.data) 106 | 107 | # weight decay 108 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 109 | 110 | # update 111 | # mom_scaled_update = (g * (update.abs().squeeze())) 112 | mom_scaled_update = (g.abs() * (update.squeeze())) 113 | regular_update = g 114 | param.data.add_(mom_scaled_update, alpha=-(group['param_lr'] * (1 - group['lr_weight']))) 115 | param.data.add_(regular_update, alpha=-(group['lr'] * group['lr_weight'])) 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /relative_optimizers/vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | 9 | class BaselineAttention(nn.Module): 10 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., **kwargs): 11 | super().__init__() 12 | inner_dim = dim_head * heads 13 | self.heads = heads 14 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 15 | self.to_k = nn.Linear(dim, inner_dim, bias = False) 16 | self.to_v = nn.Linear(dim, inner_dim, bias = False) 17 | self.to_o = nn.Linear(inner_dim, dim, bias = True) 18 | self.to_out = 
nn.Sequential( 19 | nn.Linear(inner_dim, dim), 20 | nn.Dropout(dropout) 21 | ) 22 | # xavier init 23 | nn.init.xavier_normal_(self.to_q.weight) 24 | nn.init.xavier_normal_(self.to_k.weight) 25 | nn.init.xavier_normal_(self.to_v.weight) 26 | nn.init.xavier_normal_(self.to_o[0].weight) 27 | 28 | 29 | def forward(self, x): 30 | q, k, v = self.to_q(x), self.to_k(x), self.to_v(x) 31 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q,k,v)) 32 | out = torch.nn.functional.scaled_dot_product_attention(q, k, v) 33 | out = rearrange(out, 'b h n d -> b n (h d)') 34 | out = self.to_out(out) 35 | return out 36 | 37 | 38 | class CorrelationAttention(BaselineAttention): 39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., **kwargs): 40 | super().__init__(dim, heads, dim_head, dropout) 41 | if kwargs.get("mod_q", False): 42 | self.to_q = nn.Identity() 43 | if kwargs.get("mod_k", False): 44 | self.to_k = nn.Identity() 45 | if kwargs.get("mod_v", False): 46 | self.to_v = nn.Identity() 47 | if kwargs.get("mod_o", False): 48 | self.to_o = nn.Sequential( 49 | nn.Identity(), 50 | nn.Dropout(dropout) 51 | ) 52 | 53 | def init_to_identity(module): 54 | if isinstance(module, nn.Linear): 55 | module.weight.data.fill_(0.0) 56 | module.weight.data[:, :] = torch.eye(module.weight.data.shape[1]) 57 | 58 | class CorrelationInitAttention(BaselineAttention): 59 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., **kwargs): 60 | super().__init__(dim, heads, dim_head, dropout) 61 | if kwargs.get("mod_q", False): 62 | init_to_identity(self.to_q) 63 | if kwargs.get("mod_k", False): 64 | init_to_identity(self.to_k) 65 | if kwargs.get("mod_v", False): 66 | init_to_identity(self.to_v) 67 | if kwargs.get("mod_o", False): 68 | init_to_identity(self.to_o[0]) 69 | 70 | # class ResidualLinear(nn.Linear): 71 | # def forward(self, input): 72 | # return input + super().forward(input) 73 | 74 | class ResidualLinear(nn.Linear): 75 | def __init__(self, in_features, out_features, bias=False, scale=1.0): 76 | super().__init__(in_features, out_features, bias) 77 | scale = torch.Tensor([scale]) 78 | self.register_buffer("scale", scale) 79 | 80 | def forward(self, input): 81 | return input * self.scale + super().forward(input) * self.scale 82 | 83 | class ResidualAttention(BaselineAttention): 84 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., **kwargs): 85 | super().__init__(dim, heads, dim_head, dropout) 86 | if kwargs.get("mod_q", False): 87 | self.to_q = ResidualLinear(dim, dim, bias=False) 88 | if kwargs.get("mod_k", False): 89 | self.to_k = ResidualLinear(dim, dim, bias=False) 90 | if kwargs.get("mod_v", False): 91 | self.to_v = ResidualLinear(dim, dim, bias=False) 92 | if kwargs.get("mod_o", False): 93 | self.to_o = nn.Sequential( 94 | ResidualLinear(dim, dim, bias=True), 95 | nn.Dropout(dropout) 96 | ) 97 | 98 | def patch_model(model, args, exp_args): 99 | for name, m in model.named_modules(): 100 | if hasattr(m, "attn"): 101 | if exp_args["mode"] == "correlation": 102 | m.attn = CorrelationAttention(model.config, **exp_args) 103 | elif exp_args["mode"] == "correlation_init": 104 | m.attn = CorrelationInitAttention(model.config, **exp_args) 105 | elif exp_args["mode"] == "residual": 106 | m.attn = ResidualAttention(model.config, **exp_args) 107 | elif exp_args["mode"] == "base": 108 | m.attn = BaselineAttention(model.config, **exp_args) 109 | elif exp_args["mode"] == "dummy": 110 | m.attn = nn.Identity() 111 | else: 112 | raise ValueError(f"Invalid mode: 
{exp_args['mode']}") 113 | return model -------------------------------------------------------------------------------- /residual_stream_scale/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/residual_stream_scale/__init__.py -------------------------------------------------------------------------------- /residual_stream_scale/something.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/home/ubuntu/miniconda3/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import diffusers" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 2, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "scheduler = diffusers.DDIMScheduler()" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 4, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "text/plain": [ 38 | "tensor([999, 998, 997, 996, 995, 994, 993, 992, 991, 990, 989, 988, 987, 986,\n", 39 | " 985, 984, 983, 982, 981, 980, 979, 978, 977, 976, 975, 974, 973, 972,\n", 40 | " 971, 970, 969, 968, 967, 966, 965, 964, 963, 962, 961, 960, 959, 958,\n", 41 | " 957, 956, 955, 954, 953, 952, 951, 950, 949, 948, 947, 946, 945, 944,\n", 42 | " 943, 942, 941, 940, 939, 938, 937, 936, 935, 934, 933, 932, 931, 930,\n", 43 | " 929, 928, 927, 926, 925, 924, 923, 922, 921, 920, 919, 918, 917, 916,\n", 44 | " 915, 914, 913, 912, 911, 910, 909, 908, 907, 906, 905, 904, 903, 902,\n", 45 | " 901, 900, 899, 898, 897, 896, 895, 894, 893, 892, 891, 890, 889, 888,\n", 46 | " 887, 886, 885, 884, 883, 882, 881, 880, 879, 878, 877, 876, 875, 874,\n", 47 | " 873, 872, 871, 870, 869, 868, 867, 866, 865, 864, 863, 862, 861, 860,\n", 48 | " 859, 858, 857, 856, 855, 854, 853, 852, 851, 850, 849, 848, 847, 846,\n", 49 | " 845, 844, 843, 842, 841, 840, 839, 838, 837, 836, 835, 834, 833, 832,\n", 50 | " 831, 830, 829, 828, 827, 826, 825, 824, 823, 822, 821, 820, 819, 818,\n", 51 | " 817, 816, 815, 814, 813, 812, 811, 810, 809, 808, 807, 806, 805, 804,\n", 52 | " 803, 802, 801, 800, 799, 798, 797, 796, 795, 794, 793, 792, 791, 790,\n", 53 | " 789, 788, 787, 786, 785, 784, 783, 782, 781, 780, 779, 778, 777, 776,\n", 54 | " 775, 774, 773, 772, 771, 770, 769, 768, 767, 766, 765, 764, 763, 762,\n", 55 | " 761, 760, 759, 758, 757, 756, 755, 754, 753, 752, 751, 750, 749, 748,\n", 56 | " 747, 746, 745, 744, 743, 742, 741, 740, 739, 738, 737, 736, 735, 734,\n", 57 | " 733, 732, 731, 730, 729, 728, 727, 726, 725, 724, 723, 722, 721, 720,\n", 58 | " 719, 718, 717, 716, 715, 714, 713, 712, 711, 710, 709, 708, 707, 706,\n", 59 | " 705, 704, 703, 702, 701, 700, 699, 698, 697, 696, 695, 694, 693, 692,\n", 60 | " 691, 690, 689, 688, 687, 686, 685, 684, 683, 682, 681, 680, 679, 678,\n", 61 | " 677, 676, 675, 674, 673, 672, 671, 670, 669, 668, 667, 666, 665, 664,\n", 62 | " 663, 662, 661, 660, 659, 658, 657, 656, 655, 654, 653, 652, 651, 650,\n", 63 | " 649, 648, 647, 646, 645, 644, 643, 642, 641, 640, 639, 638, 637, 
636,\n", 64 | " 635, 634, 633, 632, 631, 630, 629, 628, 627, 626, 625, 624, 623, 622,\n", 65 | " 621, 620, 619, 618, 617, 616, 615, 614, 613, 612, 611, 610, 609, 608,\n", 66 | " 607, 606, 605, 604, 603, 602, 601, 600, 599, 598, 597, 596, 595, 594,\n", 67 | " 593, 592, 591, 590, 589, 588, 587, 586, 585, 584, 583, 582, 581, 580,\n", 68 | " 579, 578, 577, 576, 575, 574, 573, 572, 571, 570, 569, 568, 567, 566,\n", 69 | " 565, 564, 563, 562, 561, 560, 559, 558, 557, 556, 555, 554, 553, 552,\n", 70 | " 551, 550, 549, 548, 547, 546, 545, 544, 543, 542, 541, 540, 539, 538,\n", 71 | " 537, 536, 535, 534, 533, 532, 531, 530, 529, 528, 527, 526, 525, 524,\n", 72 | " 523, 522, 521, 520, 519, 518, 517, 516, 515, 514, 513, 512, 511, 510,\n", 73 | " 509, 508, 507, 506, 505, 504, 503, 502, 501, 500, 499, 498, 497, 496,\n", 74 | " 495, 494, 493, 492, 491, 490, 489, 488, 487, 486, 485, 484, 483, 482,\n", 75 | " 481, 480, 479, 478, 477, 476, 475, 474, 473, 472, 471, 470, 469, 468,\n", 76 | " 467, 466, 465, 464, 463, 462, 461, 460, 459, 458, 457, 456, 455, 454,\n", 77 | " 453, 452, 451, 450, 449, 448, 447, 446, 445, 444, 443, 442, 441, 440,\n", 78 | " 439, 438, 437, 436, 435, 434, 433, 432, 431, 430, 429, 428, 427, 426,\n", 79 | " 425, 424, 423, 422, 421, 420, 419, 418, 417, 416, 415, 414, 413, 412,\n", 80 | " 411, 410, 409, 408, 407, 406, 405, 404, 403, 402, 401, 400, 399, 398,\n", 81 | " 397, 396, 395, 394, 393, 392, 391, 390, 389, 388, 387, 386, 385, 384,\n", 82 | " 383, 382, 381, 380, 379, 378, 377, 376, 375, 374, 373, 372, 371, 370,\n", 83 | " 369, 368, 367, 366, 365, 364, 363, 362, 361, 360, 359, 358, 357, 356,\n", 84 | " 355, 354, 353, 352, 351, 350, 349, 348, 347, 346, 345, 344, 343, 342,\n", 85 | " 341, 340, 339, 338, 337, 336, 335, 334, 333, 332, 331, 330, 329, 328,\n", 86 | " 327, 326, 325, 324, 323, 322, 321, 320, 319, 318, 317, 316, 315, 314,\n", 87 | " 313, 312, 311, 310, 309, 308, 307, 306, 305, 304, 303, 302, 301, 300,\n", 88 | " 299, 298, 297, 296, 295, 294, 293, 292, 291, 290, 289, 288, 287, 286,\n", 89 | " 285, 284, 283, 282, 281, 280, 279, 278, 277, 276, 275, 274, 273, 272,\n", 90 | " 271, 270, 269, 268, 267, 266, 265, 264, 263, 262, 261, 260, 259, 258,\n", 91 | " 257, 256, 255, 254, 253, 252, 251, 250, 249, 248, 247, 246, 245, 244,\n", 92 | " 243, 242, 241, 240, 239, 238, 237, 236, 235, 234, 233, 232, 231, 230,\n", 93 | " 229, 228, 227, 226, 225, 224, 223, 222, 221, 220, 219, 218, 217, 216,\n", 94 | " 215, 214, 213, 212, 211, 210, 209, 208, 207, 206, 205, 204, 203, 202,\n", 95 | " 201, 200, 199, 198, 197, 196, 195, 194, 193, 192, 191, 190, 189, 188,\n", 96 | " 187, 186, 185, 184, 183, 182, 181, 180, 179, 178, 177, 176, 175, 174,\n", 97 | " 173, 172, 171, 170, 169, 168, 167, 166, 165, 164, 163, 162, 161, 160,\n", 98 | " 159, 158, 157, 156, 155, 154, 153, 152, 151, 150, 149, 148, 147, 146,\n", 99 | " 145, 144, 143, 142, 141, 140, 139, 138, 137, 136, 135, 134, 133, 132,\n", 100 | " 131, 130, 129, 128, 127, 126, 125, 124, 123, 122, 121, 120, 119, 118,\n", 101 | " 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105, 104,\n", 102 | " 103, 102, 101, 100, 99, 98, 97, 96, 95, 94, 93, 92, 91, 90,\n", 103 | " 89, 88, 87, 86, 85, 84, 83, 82, 81, 80, 79, 78, 77, 76,\n", 104 | " 75, 74, 73, 72, 71, 70, 69, 68, 67, 66, 65, 64, 63, 62,\n", 105 | " 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48,\n", 106 | " 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34,\n", 107 | " 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20,\n", 108 | " 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 
7, 6,\n", 109 | " 5, 4, 3, 2, 1, 0])" 110 | ] 111 | }, 112 | "execution_count": 4, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "scheduler.timesteps" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [] 127 | } 128 | ], 129 | "metadata": { 130 | "kernelspec": { 131 | "display_name": "base", 132 | "language": "python", 133 | "name": "python3" 134 | }, 135 | "language_info": { 136 | "codemirror_mode": { 137 | "name": "ipython", 138 | "version": 3 139 | }, 140 | "file_extension": ".py", 141 | "mimetype": "text/x-python", 142 | "name": "python", 143 | "nbconvert_exporter": "python", 144 | "pygments_lexer": "ipython3", 145 | "version": "3.11.5" 146 | } 147 | }, 148 | "nbformat": 4, 149 | "nbformat_minor": 2 150 | } 151 | -------------------------------------------------------------------------------- /sam_optimizers/adam_two_momentum_perturb.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | import torch 4 | 5 | 6 | class AdamTwoMomentumSAM(torch.optim.Optimizer): 7 | 8 | def __init__( 9 | self, 10 | params, 11 | lr=1e-4, 12 | perturb_lr_ratio=None, 13 | beta1=0.90, 14 | beta1_perturb=0.80, 15 | beta2=0.999, 16 | eps=1e-8, 17 | weight_decay=0.01, 18 | nesterov=False, 19 | perturbation_start_step=50, 20 | ): 21 | perturb_lr_ratio = perturb_lr_ratio or 1.0 22 | defaults = dict( 23 | lr=lr, 24 | perturb_lr_ratio=perturb_lr_ratio, 25 | beta1=beta1, 26 | beta2=beta2, 27 | eps=eps, 28 | weight_decay=weight_decay, 29 | beta1_perturb=beta1_perturb, 30 | nesterov=nesterov, 31 | perturbation_start_step=perturbation_start_step, 32 | ) 33 | 34 | super().__init__(params, defaults) 35 | 36 | @torch.no_grad() 37 | def step(self, closure=None): 38 | """Perform a single optimization step. 39 | 40 | Args: 41 | closure (Callable, optional): A closure that reevaluates the model 42 | and returns the loss. 
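
        Once step > perturbation_start_step, each call (roughly) undoes the perturbation left by
        the previous call, takes a standard Adam step with the slow momentum (beta1) plus decoupled
        weight decay, and then re-perturbs the weights along the fast momentum (beta1_perturb),

            W <- W + lr * perturb_lr_ratio * exp_avg_perturb / (sqrt(exp_avg_sq) + eps) / scale

        so the next gradient is evaluated at a SAM-style displaced point.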
43 | """ 44 | 45 | loss = None 46 | if closure is not None: 47 | with torch.enable_grad(): 48 | loss = closure() 49 | 50 | for group_num, group in enumerate(self.param_groups): 51 | for i, param in enumerate(group["params"]): 52 | grad = param.grad 53 | if grad is None: 54 | continue 55 | 56 | state = self.state[param] 57 | 58 | if "exp_avg" not in state: 59 | state["exp_avg"] = torch.zeros_like(grad) 60 | state["exp_avg_perturb"] = torch.zeros_like(grad) 61 | state["exp_avg_sq"] = torch.zeros_like(grad) 62 | state["step"] = 0 63 | 64 | # do Adam update 65 | state["step"] += 1 66 | 67 | bias_correction1 = 1 - group["beta1"]**state["step"] 68 | bias_correction2 = 1 - group["beta2"]**state["step"] 69 | scale = bias_correction1 / bias_correction2**0.5 70 | 71 | if state["step"] > 1 and state["step"] > group["perturbation_start_step"]: 72 | # remove last weight decay perturbation 73 | perturb_lr = group["lr"] * group["perturb_lr_ratio"] 74 | denom = state["exp_avg_sq"].sqrt().add_(group["eps"]) 75 | param.data.addcdiv_(state["exp_avg_perturb"], denom, value=-perturb_lr/scale) 76 | ############################################################ 77 | 78 | # momentum update 79 | state["exp_avg"].lerp_(grad, 1 - group["beta1"]) 80 | state["exp_avg_perturb"].lerp_(grad, 1 - group["beta1_perturb"]) 81 | 82 | # exp avg sq update 83 | state["exp_avg_sq"].lerp_(grad.pow(2), 1 - group["beta2"]) 84 | 85 | update = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov"] else state["exp_avg"] 86 | 87 | # update 88 | denom = state["exp_avg_sq"].sqrt().add_(group["eps"]) 89 | param.data.addcdiv_(update, denom, value=-group["lr"]) 90 | 91 | # weight decay 92 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 93 | 94 | ############################################################ 95 | 96 | # Do other momentum perturbation 97 | if state["step"] > group["perturbation_start_step"]: 98 | perturb_lr = group["lr"] * group["perturb_lr_ratio"] 99 | param.data.addcdiv_(state["exp_avg_perturb"], denom, value=perturb_lr/scale) 100 | 101 | 102 | -------------------------------------------------------------------------------- /sam_optimizers/adam_wd_perturb.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | import torch 4 | 5 | 6 | class AdamWeightDecaySAM(torch.optim.Optimizer): 7 | 8 | def __init__( 9 | self, 10 | params, 11 | lr=1e-4, 12 | beta1=0.90, 13 | beta2=0.999, 14 | eps=1e-8, 15 | weight_decay=0.01, 16 | 17 | ): 18 | defaults = dict( 19 | lr=lr, 20 | beta1=beta1, 21 | beta2=beta2, 22 | eps=eps, 23 | weight_decay=weight_decay, 24 | ) 25 | 26 | super().__init__(params, defaults) 27 | 28 | @torch.no_grad() 29 | def step(self, closure=None): 30 | """Perform a single optimization step. 31 | 32 | Args: 33 | closure (Callable, optional): A closure that reevaluates the model 34 | and returns the loss. 
35 | """ 36 | 37 | loss = None 38 | if closure is not None: 39 | with torch.enable_grad(): 40 | loss = closure() 41 | 42 | for group_num, group in enumerate(self.param_groups): 43 | for i, param in enumerate(group["params"]): 44 | grad = param.grad 45 | if grad is None: 46 | continue 47 | 48 | state = self.state[param] 49 | 50 | if "exp_avg" not in state: 51 | state["exp_avg"] = torch.zeros_like(grad) 52 | state["exp_avg_sq"] = torch.zeros_like(grad) 53 | state["step"] = 0 54 | 55 | # do Adam update 56 | state["step"] += 1 57 | 58 | bias_correction1 = 1 - group["beta1"]**state["step"] 59 | bias_correction2 = 1 - group["beta2"]**state["step"] 60 | scale = bias_correction1 / bias_correction2**0.5 61 | 62 | # remove last weight decay perturbation, 63 | if state["step"] > 1: 64 | param.data.div_(1 - group["lr"] * group["weight_decay"]) 65 | ############################################################ 66 | 67 | # do Adam update 68 | og_shape = grad.shape 69 | if "exp_avg" not in state: 70 | state["exp_avg"] = torch.zeros_like(grad) 71 | state["exp_avg_sq"] = torch.zeros_like(grad) 72 | state["step"] = 0 73 | 74 | 75 | # momentum update 76 | state["exp_avg"].lerp_(grad, 1 - group["momentum"]) 77 | 78 | # exp avg sq update 79 | state["exp_avg_sq"].lerp_(grad.pow(2), 1 - group["beta2"]) 80 | 81 | # update and weight decay 82 | denom = state["exp_avg_sq"].sqrt().add_(group["eps"]) 83 | param.data.addcdiv_(state["exp_avg"], denom, value=-group["lr"]/scale) 84 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 85 | 86 | ############################################################ 87 | 88 | # Do weight decay perturbation 89 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 90 | 91 | 92 | -------------------------------------------------------------------------------- /sam_optimizers/gpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional, Tuple 4 | from einops import rearrange 5 | from torch.nn import functional as F 6 | from torch.nn.functional import scaled_dot_product_attention as sdpa 7 | from .adam_two_momentum_perturb import AdamTwoMomentumSAM 8 | from .adam_wd_perturb import AdamWeightDecaySAM 9 | from .muon_adam_perturb import MuonAdamSAM 10 | from .nesterov_perturb import NesterovPerturb 11 | from hf_gpt_blocks.hf_gpt import patch_gpt 12 | from .muon import Muon 13 | 14 | 15 | def patch_optimizer(model, args, exp_args): 16 | no_decay = ["bias", "layer_norm.weight"] 17 | optimizer_grouped_parameters = [ 18 | { 19 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 20 | "weight_decay": args.weight_decay, 21 | }, 22 | { 23 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 24 | "weight_decay": 0.0, 25 | }, 26 | ] 27 | 28 | lr = exp_args.get("lr", args.learning_rate) 29 | weight_decay = exp_args.get("weight_decay", args.weight_decay) 30 | beta1 = exp_args.get("beta1", args.beta1) 31 | beta2 = exp_args.get("beta2", args.beta2) 32 | eps = exp_args.get("eps", args.eps) 33 | 34 | if exp_args["mode"] == "adam_two_momentum_perturb": 35 | optimizer = AdamTwoMomentumSAM(optimizer_grouped_parameters, beta1=beta1, beta2=beta2, lr=lr, perturb_lr_ratio=exp_args["perturb_lr_ratio"], beta1_perturb=exp_args["beta1_perturb"], eps=eps) 36 | elif exp_args["mode"] == "adam_wd_perturb": 37 | optimizer = AdamWeightDecaySAM(optimizer_grouped_parameters, beta1=beta1, beta2=beta2, lr=lr, eps=eps) 38 | elif 
exp_args["mode"] == "muon_adam_perturb": 39 | optimizer = MuonAdamSAM(optimizer_grouped_parameters, beta1=beta1, beta2=beta2, lr=lr, perturb_lr_ratio=exp_args["perturb_lr_ratio"], eps=eps) 40 | elif exp_args["mode"] == "nesterov_perturb": 41 | optimizer = NesterovPerturb(optimizer_grouped_parameters, beta1=beta1, beta2=beta2, lr=lr, perturb_lr_ratio=exp_args["perturb_lr_ratio"], eps=eps) 42 | elif exp_args["mode"] == "muon": 43 | optimizer = Muon(optimizer_grouped_parameters, beta1=beta1, lr=lr, eps=eps) 44 | elif exp_args["mode"] == "adam": 45 | optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=lr, betas=(beta1, beta2), eps=eps, fused=not args.compile_optimizer) 46 | else: 47 | raise ValueError(f"Invalid optimizer: {exp_args['mode']}") 48 | 49 | return optimizer 50 | 51 | 52 | def patch_model(model, args, exp_args): 53 | patch_gpt(model) 54 | return model 55 | 56 | 57 | def get_run_name(args, exp_args): 58 | run_name = exp_args["mode"] 59 | args["output_dir"] = f"{args['base_output_dir']}/{run_name}" 60 | 61 | return args, run_name 62 | 63 | 64 | extra_args = { 65 | # "mode": "adam", # ["adam", "muon", "nesterov_perturb", "adam_two_momentum_perturb", "adam_wd_perturb", "muon_adam_perturb"] 66 | # "mode": "adam", 67 | 68 | "mode": "muon", 69 | "lr": 2.0e-3, 70 | "weight_decay": 0.1, 71 | "beta1": 0.95, 72 | 73 | 74 | # "mode": "adam_two_momentum_perturb", 75 | # "perturb_lr_ratio": 1.0, 76 | # "beta1_perturb": 0.95, 77 | 78 | # "mode": "adam_wd_perturb", 79 | 80 | # "mode": "muon_adam_perturb", 81 | # "perturb_lr_ratio": 1e-4, 82 | 83 | # "mode": "nesterov_perturb", 84 | # "perturb_lr_ratio": 1e-4, 85 | 86 | # "mode": "muon_adam_perturb", 87 | # "lr": 2.0e-3, 88 | # "weight_decay": 0.1, 89 | # "beta1": 0.95, 90 | # "beta2": 0.999, 91 | # "perturb_lr_ratio": 0.004, 92 | 93 | 94 | 95 | } 96 | -------------------------------------------------------------------------------- /sam_optimizers/muon.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | from .utils import zeropower_via_newtonschulz5 3 | import torch 4 | 5 | # https://github.com/KellerJordan/Muon/blob/master/muon.py 6 | 7 | 8 | 9 | class Muon(torch.optim.Optimizer): 10 | 11 | def __init__( 12 | self, 13 | params, 14 | lr=0.02, 15 | beta1=0.95, 16 | eps=1e-8, 17 | weight_decay=0.01, 18 | ns_steps=6, 19 | exp_avg_momentum=True, 20 | nesterov=False, 21 | ): 22 | defaults = dict( 23 | lr=lr, 24 | beta1=beta1, 25 | eps=eps, 26 | weight_decay=weight_decay, 27 | ns_steps=ns_steps, 28 | exp_avg_momentum=exp_avg_momentum, 29 | nesterov=nesterov 30 | ) 31 | 32 | super().__init__(params, defaults) 33 | 34 | @torch.no_grad() 35 | def step(self, closure=None): 36 | """Perform a single optimization step. 37 | 38 | Args: 39 | closure (Callable, optional): A closure that reevaluates the model 40 | and returns the loss. 
41 | """ 42 | 43 | loss = None 44 | if closure is not None: 45 | with torch.enable_grad(): 46 | loss = closure() 47 | 48 | for group_num, group in enumerate(self.param_groups): 49 | for i, param in enumerate(group["params"]): 50 | grad = param.grad 51 | if grad is None: 52 | continue 53 | 54 | # do Muon update 55 | og_shape = grad.shape 56 | if grad.ndim != 2: 57 | grad = grad.view(grad.size(0), -1) 58 | 59 | state = self.state[param] 60 | if "exp_avg" not in state: 61 | state["exp_avg"] = torch.zeros_like(grad) 62 | state["step"] = 0 63 | 64 | state["step"] += 1 65 | 66 | # momentum update 67 | if group['exp_avg_momentum']: 68 | state["exp_avg"].lerp_(grad, 1 - group["beta1"]) 69 | else: 70 | state["exp_avg"].mul_(group["beta1"]).add_(grad) 71 | 72 | update = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov"] else state["exp_avg"] 73 | 74 | # orthogonalization 75 | g = zeropower_via_newtonschulz5(update, steps=group["ns_steps"]) 76 | 77 | # rescaling 78 | g *= max(1, g.size(0)/g.size(1))**0.5 79 | g = g.view(og_shape).type_as(param.data) 80 | 81 | # update and weight decay 82 | param.data.add_(g, alpha=-group["lr"]) 83 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /sam_optimizers/muon_adam_perturb.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | from .utils import zeropower_via_newtonschulz5 3 | import torch 4 | 5 | # https://github.com/KellerJordan/Muon/blob/master/muon.py 6 | 7 | 8 | 9 | class MuonAdamSAM(torch.optim.Optimizer): 10 | 11 | def __init__( 12 | self, 13 | params, 14 | lr=0.02, 15 | perturb_lr_ratio=None, 16 | beta1=0.95, 17 | beta2=0.999, 18 | eps=1e-8, 19 | weight_decay=0.01, 20 | ns_steps=6, 21 | exp_avg_momentum=True, 22 | nesterov=False, 23 | ): 24 | perturb_lr_ratio = perturb_lr_ratio or 1.0 25 | defaults = dict( 26 | lr=lr, 27 | perturb_lr_ratio=perturb_lr_ratio, 28 | beta1=beta1, 29 | beta2=beta2, 30 | eps=eps, 31 | weight_decay=weight_decay, 32 | ns_steps=ns_steps, 33 | exp_avg_momentum=exp_avg_momentum, 34 | nesterov=nesterov 35 | ) 36 | 37 | super().__init__(params, defaults) 38 | 39 | @torch.no_grad() 40 | def step(self, closure=None): 41 | """Perform a single optimization step. 42 | 43 | Args: 44 | closure (Callable, optional): A closure that reevaluates the model 45 | and returns the loss. 
46 | """ 47 | 48 | loss = None 49 | if closure is not None: 50 | with torch.enable_grad(): 51 | loss = closure() 52 | 53 | for group_num, group in enumerate(self.param_groups): 54 | for i, param in enumerate(group["params"]): 55 | grad = param.grad 56 | if grad is None: 57 | continue 58 | 59 | state = self.state[param] 60 | 61 | og_shape = grad.shape 62 | if grad.ndim != 2: 63 | grad = grad.view(grad.size(0), -1) 64 | 65 | if "exp_avg" not in state: 66 | state["exp_avg"] = torch.zeros_like(grad) 67 | state["exp_avg_sq"] = torch.zeros_like(grad) 68 | state["step"] = 0 69 | 70 | # do Adam update 71 | state["step"] += 1 72 | 73 | bias_correction1 = 1 - group["beta1"]**state["step"] 74 | bias_correction2 = 1 - group["beta2"]**state["step"] 75 | scale = bias_correction1 / bias_correction2**0.5 76 | 77 | if state["step"] > 1: 78 | # remove last ADAM perturbation, 79 | perturb_lr = group["lr"] * group["perturb_lr_ratio"] 80 | denom = state["exp_avg_sq"].sqrt().add_(group["eps"]) 81 | param.addcdiv_(state["exp_avg"].squeeze(), denom.squeeze(), value=-perturb_lr/scale) 82 | 83 | ############################################################ 84 | 85 | # momentum update 86 | if group['exp_avg_momentum']: 87 | state["exp_avg"].lerp_(grad, 1 - group["beta1"]) 88 | else: 89 | state["exp_avg"].mul_(group["beta1"]).add_(grad) 90 | 91 | # exp avg sq update 92 | state["exp_avg_sq"].lerp_(grad.pow(2), 1 - group["beta2"]) 93 | 94 | update = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov"] else state["exp_avg"] 95 | 96 | # orthogonalization 97 | g = zeropower_via_newtonschulz5(update, steps=group["ns_steps"]) 98 | 99 | # rescaling 100 | g *= max(1, g.size(0)/g.size(1))**0.5 101 | g = g.view(og_shape).type_as(param.data) 102 | 103 | # update and weight decay 104 | param.data.add_(g, alpha=-group["lr"]) 105 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 106 | 107 | ############################################################ 108 | 109 | # Do adam perturbation 110 | denom = state["exp_avg_sq"].sqrt().add_(group["eps"]) 111 | # notice subtle lr is postivie instead of negative 112 | perturb_lr = group["lr"] * group["perturb_lr_ratio"] 113 | param.data.addcdiv_(state["exp_avg"].squeeze(), denom.squeeze(), value=perturb_lr/scale) 114 | 115 | 116 | -------------------------------------------------------------------------------- /sam_optimizers/nesterov_perturb.py: -------------------------------------------------------------------------------- 1 | from typing import Generator 2 | 3 | import torch 4 | 5 | 6 | class NesterovPerturb(torch.optim.Optimizer): 7 | 8 | def __init__( 9 | self, 10 | params, 11 | lr=1e-4, 12 | perturb_lr_ratio=None, 13 | beta1=0.90, 14 | beta2=0.999, 15 | eps=1e-8, 16 | weight_decay=0.01, 17 | nesterov_as_perturb=False, 18 | ): 19 | perturb_lr_ratio = perturb_lr_ratio or lr 20 | defaults = dict( 21 | lr=lr, 22 | perturb_lr_ratio=perturb_lr_ratio, 23 | beta1=beta1, 24 | beta2=beta2, 25 | eps=eps, 26 | weight_decay=weight_decay, 27 | nesterov_as_perturb=nesterov_as_perturb 28 | ) 29 | 30 | super().__init__(params, defaults) 31 | 32 | @torch.no_grad() 33 | def step(self, closure=None): 34 | """Perform a single optimization step. 35 | 36 | Args: 37 | closure (Callable, optional): A closure that reevaluates the model 38 | and returns the loss. 
39 | """ 40 | 41 | loss = None 42 | if closure is not None: 43 | with torch.enable_grad(): 44 | loss = closure() 45 | 46 | for group_num, group in enumerate(self.param_groups): 47 | for i, param in enumerate(group["params"]): 48 | grad = param.grad 49 | if grad is None: 50 | continue 51 | 52 | state = self.state[param] 53 | 54 | if "exp_avg" not in state: 55 | state["exp_avg"] = torch.zeros_like(grad) 56 | state["exp_avg_sq"] = torch.zeros_like(grad) 57 | state["step"] = 0 58 | 59 | # do Adam update 60 | state["step"] += 1 61 | 62 | bias_correction1 = 1 - group["beta1"]**state["step"] 63 | bias_correction2 = 1 - group["beta2"]**state["step"] 64 | scale = bias_correction1 / bias_correction2**0.5 65 | 66 | if state["step"] > 1: 67 | # remove last weight decay perturbation 68 | perturb_lr = group["lr"] * group["perturb_lr_ratio"] 69 | perturb = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov_as_perturb"] else state["exp_avg"] 70 | denom = state["exp_avg_sq"].sqrt().add_(group["eps"]) 71 | param.data.addcdiv_(perturb, denom, value=-perturb_lr/scale) 72 | ############################################################ 73 | 74 | # do Adam update 75 | og_shape = grad.shape 76 | 77 | # momentum update 78 | state["exp_avg"].lerp_(grad, 1 - group["beta1"]) 79 | 80 | # exp avg sq update 81 | state["exp_avg_sq"].lerp_(grad.pow(2), 1 - group["beta2"]) 82 | 83 | update = grad.lerp_(state["exp_avg"], group["beta1"]) if not group["nesterov_as_perturb"] else state["exp_avg"] 84 | 85 | # update 86 | denom = state["exp_avg_sq"].sqrt().add_(group["eps"]) 87 | param.data.addcdiv_(update, denom, value=-perturb_lr) 88 | 89 | # weight decay 90 | param.data.mul_(1 - group["lr"] * group["weight_decay"]) 91 | 92 | ############################################################ 93 | 94 | perturb = grad.lerp_(state["exp_avg"], group["beta1"]) if group["nesterov_as_perturb"] else state["exp_avg"] 95 | 96 | # Do other momentum perturbation 97 | param.data.addcdiv_(perturb, denom, value=perturb_lr/scale) 98 | 99 | 100 | -------------------------------------------------------------------------------- /sam_optimizers/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # @torch.compile 5 | def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7): 6 | """ 7 | Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a 8 | quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose 9 | of minimizing steps, it turns out to be empirically effective to keep increasing the slope at 10 | zero even beyond the point where the iteration no longer converges all the way to one everywhere 11 | on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T 12 | where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model 13 | performance at all relative to UV^T, where USV^T = G is the SVD. 
14 | """ 15 | assert len(G.shape) == 2 16 | a, b, c = (3.4445, -4.7750, 2.0315) 17 | X = G.bfloat16() 18 | X /= X.norm() + eps # ensure top singular value <= 1 19 | if G.size(0) > G.size(1): 20 | X = X.T 21 | for _ in range(steps): 22 | A = X @ X.T 23 | B = b * A + c * A @ A 24 | X = a * X + B @ X 25 | if G.size(0) > G.size(1): 26 | X = X.T 27 | return X 28 | -------------------------------------------------------------------------------- /sparsemax_attn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/sparsemax_attn/__init__.py -------------------------------------------------------------------------------- /sparsemax_attn/bert/patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | from torch import nn 4 | import math 5 | 6 | from typing import Optional, Tuple 7 | 8 | #model.bert.encoder.layer[i].attention.self 9 | 10 | class PatchedBertSelfAttention(nn.Module): 11 | def __init__(self, config, position_embedding_type=None, dipole_attn=False): 12 | super().__init__() 13 | if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): 14 | raise ValueError( 15 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 16 | f"heads ({config.num_attention_heads})" 17 | ) 18 | 19 | self.num_attention_heads = config.num_attention_heads 20 | self.attention_head_size = int(config.hidden_size / config.num_attention_heads) 21 | self.all_head_size = self.num_attention_heads * self.attention_head_size 22 | 23 | self.query = nn.Linear(config.hidden_size, self.all_head_size) 24 | self.key = nn.Linear(config.hidden_size, self.all_head_size) 25 | self.value = nn.Linear(config.hidden_size, self.all_head_size) 26 | 27 | self.dropout = nn.Dropout(config.attention_probs_dropout_prob) 28 | self.position_embedding_type = position_embedding_type or getattr( 29 | config, "position_embedding_type", "absolute" 30 | ) 31 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 32 | self.max_position_embeddings = config.max_position_embeddings 33 | self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) 34 | 35 | self.is_decoder = config.is_decoder 36 | 37 | ########## 38 | self.dipole_attn = dipole_attn 39 | self.pos_weights = nn.Parameter(torch.ones(1, self.num_attention_heads, 1, 1)) 40 | self.neg_weights = nn.Parameter(torch.ones(1, self.num_attention_heads, 1, 1)) 41 | self.value2 = nn.Linear(config.hidden_size, self.all_head_size) 42 | ########## 43 | 44 | def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: 45 | new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) 46 | x = x.view(new_x_shape) 47 | return x.permute(0, 2, 1, 3) 48 | 49 | def forward( 50 | self, 51 | hidden_states: torch.Tensor, 52 | attention_mask: Optional[torch.FloatTensor] = None, 53 | head_mask = None, 54 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 55 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 56 | past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 57 | output_attentions: Optional[bool] = False, 58 | ) -> Tuple[torch.Tensor]: 59 | mixed_query_layer = self.query(hidden_states) 60 | 61 | # If this is instantiated as a cross-attention module, the keys 
62 | # and values come from an encoder; the attention mask needs to be 63 | # such that the encoder's padding tokens are not attended to. 64 | is_cross_attention = encoder_hidden_states is not None 65 | 66 | if is_cross_attention and past_key_value is not None: 67 | # reuse k,v, cross_attentions 68 | key_layer = past_key_value[0] 69 | value_layer = past_key_value[1] 70 | value_2_layer = past_key_value[2] 71 | attention_mask = encoder_attention_mask 72 | else: 73 | ctx = encoder_hidden_states if is_cross_attention else hidden_states 74 | key_layer = self.transpose_for_scores(self.key(ctx)) 75 | value_layer = self.transpose_for_scores(self.value(ctx)) 76 | value_2_layer = self.transpose_for_scores(self.value2(ctx)) 77 | if is_cross_attention: 78 | attention_mask = encoder_attention_mask 79 | elif past_key_value is not None: 80 | key_layer = torch.cat([past_key_value[0], key_layer], dim=2) 81 | value_layer = torch.cat([past_key_value[1], value_layer], dim=2) 82 | value_2_layer = torch.cat([past_key_value[2], value_2_layer], dim=2) 83 | 84 | query_layer = self.transpose_for_scores(mixed_query_layer) 85 | 86 | use_cache = past_key_value is not None 87 | if self.is_decoder: 88 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 89 | # Further calls to cross_attention layer can then reuse all cross-attention 90 | # key/value_states (first "if" case) 91 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 92 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 93 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 94 | # if encoder bi-directional self-attention `past_key_value` is always `None` 95 | past_key_value = (key_layer, value_layer, value_2_layer) 96 | 97 | # Take the dot product between "query" and "key" to get the raw attention scores. 
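# query_layer / key_layer have shape (batch, heads, seq_len, head_dim) after transpose_for_scores, so the scores computed below have shape (batch, heads, query_len, key_len)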
98 | attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) 99 | 100 | if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": 101 | query_length, key_length = query_layer.shape[2], key_layer.shape[2] 102 | if use_cache: 103 | position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( 104 | -1, 1 105 | ) 106 | else: 107 | position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) 108 | position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) 109 | distance = position_ids_l - position_ids_r 110 | 111 | positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) 112 | positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility 113 | 114 | if self.position_embedding_type == "relative_key": 115 | relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 116 | attention_scores = attention_scores + relative_position_scores 117 | elif self.position_embedding_type == "relative_key_query": 118 | relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) 119 | relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) 120 | attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key 121 | 122 | attention_scores = attention_scores / math.sqrt(self.attention_head_size) 123 | if attention_mask is not None: 124 | attention_scores = attention_scores + attention_mask 125 | 126 | if self.dipole_attn: 127 | attn_probs = torch.exp(attention_scores) 128 | neg_probs = 1 / (attn_probs + 1e-6) 129 | if attention_mask is not None: 130 | attn_mask = attention_mask < -1 131 | neg_probs = torch.where(attn_mask, torch.zeros_like(neg_probs), neg_probs) 132 | attn_probs = attn_probs / attn_probs.sum(dim=-1, keepdim=True) 133 | attn_probs = self.dropout(attn_probs) 134 | neg_probs = neg_probs / neg_probs.sum(dim=-1, keepdim=True) 135 | neg_probs = self.dropout(neg_probs) 136 | 137 | pos_context_layer = torch.matmul(attn_probs, value_layer) * self.pos_weights 138 | neg_context_layer = torch.matmul(neg_probs, value_2_layer) * self.neg_weights 139 | context_layer = pos_context_layer + neg_context_layer 140 | else: 141 | # Normalize the attention scores to probabilities. 
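# (standard softmax path; the dipole branch above instead normalizes exp(scores) and its reciprocal separately and mixes the two value projections)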
142 | attention_probs = nn.functional.softmax(attention_scores, dim=-1) 143 | attention_probs = self.dropout(attention_probs) 144 | context_layer = torch.matmul(attention_probs, value_layer) 145 | 146 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 147 | new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) 148 | context_layer = context_layer.view(new_context_layer_shape) 149 | 150 | outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) 151 | 152 | if self.is_decoder: 153 | outputs = outputs + (past_key_value,) 154 | return outputs 155 | 156 | 157 | def patch_model(model, dipole_attn=False, last_n_layers=2): 158 | for i in range(len(model.bert.encoder.layer)): 159 | if i <= len(model.bert.encoder.layer) - last_n_layers - 1: 160 | continue 161 | model.bert.encoder.layer[i].attention.self = PatchedBertSelfAttention(model.config, dipole_attn=dipole_attn) 162 | # also reinit weights of the other parts 163 | model.bert.encoder.layer[i].attention.output = transformers.models.bert.modeling_bert.BertSelfOutput(model.config) 164 | model.bert.encoder.layer[i].intermediate = transformers.models.bert.modeling_bert.BertIntermediate(model.config) 165 | model.bert.encoder.layer[i].output = transformers.models.bert.modeling_bert.BertOutput(model.config) 166 | 167 | 168 | def gather_params(model, last_n_layers=2): 169 | params = {} 170 | for n, p in model.named_parameters(): 171 | if '.layer.' in n: 172 | layer_num = int(n.split('.')[3]) 173 | if layer_num >= len(model.bert.encoder.layer) - last_n_layers: 174 | params[n] = p 175 | else: 176 | if not n in ['bert.embeddings.word_embeddings.weight', 'bert.embeddings.position_embeddings.weight', 'bert.embeddings.token_type_embeddings.weight', 'bert.embeddings.LayerNorm.weight', 'bert.embeddings.LayerNorm.bias']: 177 | params[n] = p 178 | return params -------------------------------------------------------------------------------- /structured_transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/TransformerExperiments/eb77920ba61e42b87c2090fb01b2c94183cfe43b/structured_transformer/__init__.py -------------------------------------------------------------------------------- /structured_transformer/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | #https://github.com/ethansmith2000/SparseNetworks 6 | 7 | class Relu2(nn.Module): 8 | def forward(self, x): 9 | return F.relu(x).square() 10 | 11 | def get_activation(activation): 12 | if activation == 'relu': 13 | return nn.ReLU() 14 | elif activation == 'silu': 15 | return nn.SiLU() 16 | elif activation == 'gelu': 17 | return nn.GELU() 18 | elif activation == 'tanh': 19 | return nn.Tanh() 20 | elif activation == 'sigmoid': 21 | return nn.Sigmoid() 22 | elif activation == "relu2": 23 | return Relu2() 24 | else: 25 | raise NotImplementedError(f"Activation {activation} not implemented") 26 | 27 | class PermuteIn(nn.Module): 28 | 29 | def __init__(self, 30 | full_dim, 31 | heads, 32 | mode="structured", # random, roll, chunk_random, structured 33 | roll=0.4, 34 | chunks=4, # must divide the chunk dim evenly 35 | ): 36 | super().__init__() 37 | block_dim = full_dim // heads 38 | roll = int(roll * full_dim) 39 | if mode == "random": 40 | permute = torch.randperm(full_dim) 41 | elif mode == "roll": 42 | permute = 
torch.roll(torch.arange(full_dim), roll) 43 | elif mode == "chunk_random": 44 | assert block_dim % chunks == 0, "chunks must divide the dim evenly" 45 | chunk_indices = torch.randperm(full_dim // (block_dim // chunks)) 46 | permute = torch.cat([torch.arange((block_dim // chunks)) + i * (block_dim // chunks) for i in chunk_indices]) 47 | elif mode == "structured": 48 | indices = torch.arange(full_dim) 49 | permute = (indices % heads) * block_dim + (indices // heads) 50 | else: 51 | raise NotImplementedError("mode not implemented") 52 | self.register_buffer("permute", permute) 53 | 54 | def forward(self, x): 55 | return x[:, self.permute] 56 | 57 | 58 | class Unpermute(nn.Module): 59 | 60 | def __init__(self, indices): 61 | super().__init__() 62 | perm_matrix = F.one_hot(indices, num_classes=indices.shape[0]).float() 63 | unperm_matrix = perm_matrix.inverse() 64 | unperm = unperm_matrix.argmax(dim=-1).long() 65 | self.register_buffer("unperm", unperm) 66 | 67 | def forward(self, x): 68 | return x[:, self.unperm] 69 | 70 | 71 | 72 | class AddBias(nn.Module): 73 | 74 | def __init__(self, dim): 75 | super(AddBias, self).__init__() 76 | self.bias = nn.Parameter(torch.zeros(dim)) 77 | 78 | def forward(self, x): 79 | return x + self.bias 80 | 81 | 82 | class LowRankLinear(nn.Module): 83 | """ 84 | Like LoRA but without base layer 85 | """ 86 | def __init__(self, in_features, out_features, rank, bias=True): 87 | super(LowRankLinear, self).__init__() 88 | self.in_features = in_features 89 | self.out_features = out_features 90 | self.rank = rank 91 | self.weight1 = nn.Parameter(torch.Tensor(in_features, rank)) 92 | self.weight2 = nn.Parameter(torch.Tensor(rank, out_features)) 93 | self.add_bias = AddBias(out_features) if bias else torch.nn.Identity() 94 | self.reset_parameters() 95 | 96 | def reset_parameters(self): 97 | nn.init.kaiming_uniform_(self.weight1, a=5**0.5) 98 | nn.init.kaiming_uniform_(self.weight2, a=5**0.5) 99 | if isinstance(self.add_bias, AddBias): # the bias parameter lives inside the AddBias module 100 | nn.init.zeros_(self.add_bias.bias) 101 | 102 | def forward(self, x): 103 | return self.add_bias(torch.mm(torch.mm(x, self.weight1), self.weight2)) 104 | 105 | 106 | class SparseLinear(nn.Module): 107 | 108 | """ 109 | Kinda like Monarch Matrices I think 110 | """ 111 | 112 | def __init__(self, full_in_dim=1024, full_out_dim=1024, heads=8, bias=True): 113 | super(SparseLinear, self).__init__() 114 | self.full_in = full_in_dim 115 | self.full_out = full_out_dim 116 | self.in_dim = full_in_dim // heads 117 | self.out_dim = full_out_dim // heads 118 | self.h = heads 119 | weights = [torch.randn(self.in_dim, self.out_dim) for _ in range(heads)] 120 | for i in range(len(weights)): 121 | torch.nn.init.kaiming_uniform_(weights[i], nonlinearity='relu') 122 | self.weight = nn.Parameter(torch.stack(weights, dim=0)) 123 | self.bias_add = AddBias(self.full_out) if bias else nn.Identity() 124 | 125 | def forward(self, x): 126 | b, h, in_dim = x.shape[0], self.h, self.in_dim 127 | x = x.reshape(b, h, in_dim) 128 | x = torch.einsum('bhd,hdl->bhl', x, self.weight) 129 | x = x.reshape(b, h * self.out_dim) 130 | x = self.bias_add(x) 131 | return x 132 | 133 | 134 | 135 | class SparseMLPResidual(nn.Module): 136 | """ 137 | permute/unpermute operation to align with residual stream 138 | """ 139 | 140 | def __init__(self, full_dim=1024, 141 | heads=8, 142 | act="gelu", 143 | full_mlp_dim=4096, 144 | unperm=True, 145 | dropout=0., 146 | permute_mode="structured", # ["random", "roll", "chunk_random", "linear", "structured"] 147 | bias=True 148 | ): 149 | 
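# flow: permute features so different blocks interact -> block-diagonal up-projection -> activation -> block-diagonal down-projection -> unpermute back to the original feature order so the output lines up with the residual stream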
super().__init__() 150 | self.up = SparseLinear(full_dim, full_mlp_dim, heads, bias=bias) 151 | self.down = SparseLinear(full_mlp_dim, full_dim, heads, bias=bias) 152 | self.act = get_activation(act) if act is not None else nn.Identity() 153 | 154 | self.unperm = nn.Identity() 155 | if permute_mode != "linear": 156 | self.perm = PermuteIn(full_dim, heads, mode=permute_mode) 157 | if unperm: 158 | self.unperm = Unpermute(self.perm.permute) 159 | else: 160 | self.perm = nn.Linear(full_dim, full_dim) 161 | if unperm: 162 | self.unperm = nn.Linear(full_dim, full_dim) 163 | 164 | self.dropout = nn.Dropout(dropout) 165 | 166 | def forward(self, x): 167 | x = self.perm(x) # reorder features to have different interactions 168 | x = self.up(x) 169 | x = self.act(x) 170 | x = self.dropout(x) 171 | x = self.down(x) 172 | x = self.dropout(x) 173 | x = self.unperm(x) 174 | 175 | return x 176 | 177 | class SparseFeedForward(SparseMLPResidual): 178 | 179 | def forward(self, x): 180 | b, toks, d = x.shape 181 | x = x.reshape(b * toks, d) 182 | x = super().forward(x) 183 | x = x.reshape(b, toks, d) 184 | return x 185 | 186 | 187 | 188 | class SparseMLP(nn.Module): 189 | """ 190 | Closer to how monarch matrices does it i think 191 | """ 192 | 193 | def __init__(self, full_dim=1024, 194 | heads=8, 195 | act="gelu", 196 | full_mlp_dim=4096, 197 | dropout=0., 198 | permute_mode="structured", # ["random", "roll", "chunk_random", "linear", "structured"] 199 | bias=True 200 | ): 201 | super().__init__() 202 | self.up = SparseLinear(full_dim, full_mlp_dim, heads, bias=bias) 203 | self.down = SparseLinear(full_mlp_dim, full_dim, heads, bias=bias) 204 | self.act = get_activation(act) if act is not None else nn.Identity() 205 | 206 | if permute_mode != "linear": 207 | self.perm = PermuteIn(full_mlp_dim, heads, mode=permute_mode) 208 | else: 209 | self.perm = nn.Linear(full_mlp_dim, full_mlp_dim) 210 | 211 | self.dropout = nn.Dropout(dropout) 212 | 213 | def forward(self, x): 214 | x = self.up(x) 215 | x = self.act(x) 216 | x = self.dropout(x) 217 | x = self.perm(x) 218 | x = self.down(x) 219 | x = self.dropout(x) 220 | 221 | return x 222 | 223 | 224 | -------------------------------------------------------------------------------- /structured_transformer/train_cifar10.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | extra_args = dict( 4 | act = "gelu", 5 | post_attn_act = None, 6 | low_rank = 128, 7 | sparse_heads = 8, 8 | qkv_mode = "low_rank", 9 | to_out_mode = "normal", 10 | ) 11 | 12 | -------------------------------------------------------------------------------- /structured_transformer/train_gpt.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import json 4 | import logging 5 | import math 6 | import os 7 | import random 8 | from itertools import chain 9 | from pathlib import Path 10 | 11 | import datasets 12 | import torch 13 | from accelerate import Accelerator, DistributedType 14 | from accelerate.logging import get_logger 15 | from accelerate.utils import set_seed 16 | from datasets import load_dataset 17 | from huggingface_hub import Repository, create_repo 18 | from torch.utils.data import DataLoader 19 | from tqdm.auto import tqdm 20 | 21 | import transformers 22 | from transformers import ( 23 | CONFIG_MAPPING, 24 | MODEL_MAPPING, 25 | AutoConfig, 26 | AutoModelForCausalLM, 27 | AutoTokenizer, 28 | SchedulerType, 29 | default_data_collator, 30 | get_scheduler, 31 | ) 32 | from 
transformers.utils import check_min_version, send_example_telemetry 33 | from types import SimpleNamespace 34 | 35 | import torch 36 | import torch.nn as nn 37 | from typing import Optional, Tuple, Union 38 | import time 39 | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention 40 | 41 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 42 | # check_min_version("4.39.0.dev0") 43 | 44 | logger = get_logger(__name__) 45 | 46 | # require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") 47 | 48 | MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) 49 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 50 | 51 | 52 | 53 | class Conv1D(nn.Module): 54 | """ 55 | 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). 56 | 57 | Basically works like a linear layer but the weights are transposed. 58 | 59 | Args: 60 | nf (`int`): The number of output features. 61 | nx (`int`): The number of input features. 62 | """ 63 | 64 | def __init__(self, nf, nx): 65 | super().__init__() 66 | self.nf = nf 67 | self.weight = nn.Parameter(torch.empty(nx, nf)) 68 | self.bias = nn.Parameter(torch.zeros(nf)) 69 | nn.init.normal_(self.weight, std=0.02) 70 | 71 | def forward(self, x): 72 | size_out = x.size()[:-1] + (self.nf,) 73 | x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) 74 | x = x.view(size_out) 75 | return x 76 | 77 | 78 | class GPT2MLP(nn.Module): 79 | def __init__(self, config, activation_type='gelu', power=1.0): 80 | super().__init__() 81 | embed_dim = config.hidden_size 82 | intermediate_size = embed_dim * 4 83 | self.c_fc = Conv1D(intermediate_size, embed_dim) 84 | self.c_proj = Conv1D(embed_dim, intermediate_size) 85 | self.act = Activation(activation_type=activation_type, power=power) 86 | self.dropout = nn.Dropout(config.resid_pdrop) 87 | 88 | def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: 89 | hidden_states = self.c_fc(hidden_states) 90 | hidden_states = self.act(hidden_states) 91 | hidden_states = self.c_proj(hidden_states) 92 | hidden_states = self.dropout(hidden_states) 93 | return hidden_states 94 | 95 | 96 | class NewGPT2Attention(GPT2Attention): 97 | def __init__(self, config, is_cross_attention=False, layer_idx=None, value_act=None, post_attn_act=None, power=1.0): 98 | super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx) 99 | # self.value_act = nn.Identity() if value_act is None else Activation(value_act) 100 | # self.post_attn_act = nn.Identity() if post_attn_act is None else Activation(post_attn_act) 101 | dim = self.c_attn.nf // 3 102 | self.value_act = nn.Identity() if value_act is None else LinearAct(dim, dim, activation_type=value_act, power=power, pre_act=True) 103 | self.post_attn_act = nn.Identity() if post_attn_act is None else LinearAct(dim, dim, activation_type=post_attn_act, power=power, pre_act=True) 104 | 105 | def forward( 106 | self, 107 | hidden_states: Optional[Tuple[torch.FloatTensor]], 108 | layer_past: Optional[Tuple[torch.Tensor]] = None, 109 | attention_mask: Optional[torch.FloatTensor] = None, 110 | head_mask: Optional[torch.FloatTensor] = None, 111 | encoder_hidden_states: Optional[torch.Tensor] = None, 112 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 113 | use_cache: Optional[bool] = False, 114 | output_attentions: Optional[bool] = False, 115 | ) -> Tuple[Union[torch.Tensor, 
Tuple[torch.Tensor]], ...]: 116 | if encoder_hidden_states is not None: 117 | if not hasattr(self, "q_attn"): 118 | raise ValueError( 119 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 120 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 121 | ) 122 | 123 | query = self.q_attn(hidden_states) 124 | key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) 125 | attention_mask = encoder_attention_mask 126 | else: 127 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 128 | 129 | value = self.value_act(value) 130 | 131 | query = self._split_heads(query, self.num_heads, self.head_dim) 132 | key = self._split_heads(key, self.num_heads, self.head_dim) 133 | value = self._split_heads(value, self.num_heads, self.head_dim) 134 | 135 | if layer_past is not None: 136 | past_key, past_value = layer_past 137 | key = torch.cat((past_key, key), dim=-2) 138 | value = torch.cat((past_value, value), dim=-2) 139 | 140 | if use_cache is True: 141 | present = (key, value) 142 | else: 143 | present = None 144 | 145 | if self.reorder_and_upcast_attn: 146 | attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) 147 | else: 148 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 149 | 150 | attn_output = self.post_attn_act(attn_output) 151 | 152 | attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) 153 | attn_output = self.c_proj(attn_output) 154 | attn_output = self.resid_dropout(attn_output) 155 | 156 | outputs = (attn_output, present) 157 | if output_attentions: 158 | outputs += (attn_weights,) 159 | 160 | return outputs # a, present, (attentions) 161 | 162 | 163 | def patch_attn(model, value_act=None, post_attn_act=None, power=1.0): 164 | conf = model.config 165 | idx = 0 166 | for n,m in model.named_modules(): 167 | if hasattr(m, "attn"): 168 | del m.attn 169 | m.add_module("attn", NewGPT2Attention(conf, is_cross_attention=False, layer_idx=None, value_act=value_act, post_attn_act=post_attn_act, power=power)) 170 | idx += 1 171 | print('current idx', idx) 172 | 173 | 174 | def patch_mlp(model, activation_type, activation_powers): 175 | idx = 0 176 | for n,m in model.named_modules(): 177 | if hasattr(m, "mlp"): 178 | del m.mlp 179 | m.add_module("mlp", GPT2MLP(model.config, activation_type=activation_type[idx], power=activation_powers[idx])) 180 | idx += 1 181 | 182 | -------------------------------------------------------------------------------- /structured_transformer/vit.py: -------------------------------------------------------------------------------- 1 | # https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | from .common import get_activation, SparseFeedForward, SparseLinear, LowRankLinear, SparseMLPResidual, SparseMLP 9 | 10 | # helpers 11 | 12 | def pair(t): 13 | return t if isinstance(t, tuple) else (t, t) 14 | 15 | # classes 16 | 17 | class PreNorm(nn.Module): 18 | def __init__(self, dim, fn): 19 | super().__init__() 20 | self.norm = nn.LayerNorm(dim) 21 | self.fn = fn 22 | def forward(self, x, **kwargs): 23 | return self.fn(self.norm(x), **kwargs) 24 | 25 | 26 | class FeedForward(nn.Module): 27 | def __init__(self, dim, hidden_dim, dropout = 0., act="gelu"): 28 | 
super().__init__() 29 | self.net = nn.Sequential( 30 | nn.Linear(dim, hidden_dim), 31 | get_activation(act), 32 | nn.Dropout(dropout), 33 | nn.Linear(hidden_dim, dim), 34 | nn.Dropout(dropout) 35 | ) 36 | def forward(self, x): 37 | return self.net(x) 38 | 39 | 40 | def get_linear(dim, out_dim, mode="normal", rank=None, sparse_heads=None, bias=True): 41 | if mode == "normal": 42 | return nn.Linear(dim, out_dim, bias=bias) 43 | elif mode == "low_rank": 44 | return LowRankLinear(dim, out_dim, rank, bias=bias) 45 | elif mode == "sparse": 46 | return SparseLinear(dim, out_dim, heads=sparse_heads, bias=bias) 47 | else: 48 | raise ValueError("Invalid mode") 49 | 50 | 51 | class Attention(nn.Module): 52 | def __init__(self, 53 | dim=512, 54 | heads = 8, 55 | dropout = 0., 56 | qkv_mode="low_rank", # normal, low_rank, sparse 57 | to_out_mode="normal", # normal, low_rank, sparse 58 | post_attn_act=None, 59 | rank=None, 60 | sparse_heads=None 61 | ): 62 | super().__init__() 63 | self.heads = heads 64 | self.to_qkv = get_linear(dim, dim * 3, mode=qkv_mode, rank=None if rank is None else rank * 3, sparse_heads=sparse_heads, bias=False) 65 | self.to_out = nn.Sequential( 66 | get_linear(dim, dim, mode=to_out_mode, rank=rank, sparse_heads=sparse_heads, bias=False), 67 | nn.Dropout(dropout) 68 | ) 69 | self.post_attn_act = nn.Identity() 70 | if post_attn_act is not None: 71 | self.post_attn_act = nn.Sequential(get_linear(dim, dim, mode=to_out_mode, rank=rank, sparse_heads=sparse_heads, bias=False), 72 | get_activation(post_attn_act)) 73 | 74 | 75 | def forward(self, x): 76 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (self.to_qkv(x).chunk(3, dim = -1))) 77 | out = F.scaled_dot_product_attention(q, k, v) 78 | out = rearrange(out, 'b h n d -> b n (h d)') 79 | out = self.post_attn_act(out) 80 | return self.to_out(out) 81 | 82 | 83 | class Transformer(nn.Module): 84 | def __init__(self, dim, 85 | depth, 86 | heads, 87 | mlp_dim, 88 | dropout = 0., 89 | act="gelu", 90 | qkv_mode="low_rank", 91 | to_out_mode="sparse", 92 | post_attn_act=None, 93 | low_rank=None, 94 | sparse_heads=None): 95 | super().__init__() 96 | self.layers = nn.ModuleList([]) 97 | for i in range(depth): 98 | self.layers.append(nn.ModuleList([ 99 | PreNorm(dim, Attention(dim, heads = heads, dropout = dropout, qkv_mode=qkv_mode, to_out_mode=to_out_mode, post_attn_act=post_attn_act, rank=low_rank, sparse_heads=sparse_heads)), 100 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout, act=act) if sparse_heads is None else SparseFeedForward(dim, sparse_heads, full_mlp_dim=mlp_dim, act=act)) 101 | ])) 102 | 103 | def forward(self, x): 104 | for attn, ff in self.layers: 105 | x = attn(x) + x 106 | x = ff(x) + x 107 | return x 108 | 109 | class ViT(nn.Module): 110 | def __init__(self, *, 111 | image_size, 112 | patch_size, 113 | num_classes, 114 | dim, 115 | depth, 116 | heads, 117 | mlp_dim, 118 | channels = 3, 119 | dropout = 0., 120 | emb_dropout = 0., 121 | 122 | act="gelu", 123 | post_attn_act=None, 124 | low_rank=None, 125 | sparse_heads=None, 126 | qkv_mode="low_rank", 127 | to_out_mode="sparse" 128 | ): 129 | super().__init__() 130 | image_height, image_width = pair(image_size) 131 | patch_height, patch_width = pair(patch_size) 132 | 133 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
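# the image is split into (image_height // patch_height) * (image_width // patch_width) patches, each flattened to channels * patch_height * patch_width values before the linear patch embedding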
134 | 135 | num_patches = (image_height // patch_height) * (image_width // patch_width) 136 | patch_dim = channels * patch_height * patch_width 137 | 138 | self.to_patch_embedding = nn.Sequential( 139 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 140 | nn.Linear(patch_dim, dim), 141 | ) 142 | 143 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 144 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 145 | self.dropout = nn.Dropout(emb_dropout) 146 | 147 | self.transformer = Transformer(dim, depth, heads, mlp_dim, 148 | dropout, act=act, qkv_mode=qkv_mode, to_out_mode=to_out_mode, 149 | post_attn_act=post_attn_act, 150 | low_rank=low_rank, sparse_heads=sparse_heads) 151 | 152 | 153 | self.mlp_head = nn.Sequential( 154 | nn.LayerNorm(dim), 155 | nn.Linear(dim, num_classes) 156 | ) 157 | 158 | def get_feat(self, img): 159 | x = self.to_patch_embedding(img) 160 | b, n, _ = x.shape 161 | 162 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 163 | x = torch.cat((cls_tokens, x), dim=1) 164 | x += self.pos_embedding[:, :(n + 1)] 165 | x = self.dropout(x) 166 | 167 | x = self.transformer(x) 168 | 169 | x = x[:, 0] 170 | 171 | return x 172 | 173 | def forward(self, img): 174 | x = self.get_feat(img) 175 | return self.mlp_head(x) 176 | -------------------------------------------------------------------------------- /train_cifar10.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | ''' 3 | 4 | Train CIFAR10 with PyTorch and Vision Transformers! 5 | written by @kentaroy47, @arutema47 6 | 7 | ''' 8 | 9 | from __future__ import print_function 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | import torch.nn.functional as F 15 | import torch.backends.cudnn as cudnn 16 | import numpy as np 17 | 18 | import torchvision 19 | import torchvision.transforms as transforms 20 | 21 | import os 22 | from types import SimpleNamespace 23 | import pandas as pds 24 | import csv 25 | import time 26 | import importlib 27 | from common.cifar_utils import progress_bar, load_data, train, test 28 | from common.vit import ViT 29 | from common.randomaug import RandAugment 30 | 31 | 32 | default_args = dict( 33 | lr = 1e-4, 34 | opt = "adam", 35 | resume = False, 36 | aug = True, 37 | mp_dtype = "bf16", 38 | wandb = True, 39 | mixup = True, 40 | net = "vit", 41 | bs = 512, 42 | size = 32, 43 | n_epochs = 100, 44 | patch = 4, 45 | dim = 512, 46 | depth=6, 47 | num_classes=10, 48 | compile=True, 49 | dropout=0.1, 50 | emb_dropout=0.1, 51 | ) 52 | 53 | experiment_args = dict( 54 | experiment="mlp_mods", 55 | 56 | ) 57 | 58 | 59 | def train_model(args, exp_args): 60 | exp_module = importlib.import_module(f"{exp_args['experiment']}.vit") 61 | 62 | # defaults 63 | extra_args = exp_module.extra_args 64 | for k, v in extra_args.items(): 65 | if k not in exp_args: 66 | exp_args[k] = v 67 | 68 | args, run_name = exp_module.get_run_name(args, exp_args) 69 | 70 | def loss_fn(net_fwd, inputs, targets): 71 | pred_labels = net_forward(inputs) 72 | loss = nn.CrossEntropyLoss()(pred_labels, targets) 73 | return loss, pred_labels 74 | 75 | args = SimpleNamespace(**args) 76 | args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 77 | args.best_acc = 0 # best test accuracy 78 | args.start_epoch = 0 # start from epoch 0 or last checkpoint epoch 79 | 80 | trainloader, testloader = load_data(args) 81 | 82 | net = ViT( 83 | dim=args.dim, 84 | 
depth=args.depth, 85 | dropout=args.dropout, 86 | emb_dropout=args.emb_dropout, 87 | ) 88 | 89 | exp_module.patch_model(net, exp_args) 90 | net = net.to(args.device) 91 | 92 | net_forward = torch.compile(net.forward) if args.compile else net.forward 93 | 94 | print("NUM PARAMS: ", sum([p.numel() for p in net.parameters()])) 95 | 96 | if args.resume: 97 | # Load checkpoint. 98 | print('==> Resuming from checkpoint..') 99 | assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!' 100 | checkpoint = torch.load('./checkpoint/{}-ckpt.t7'.format(args.net)) 101 | net.load_state_dict(checkpoint['net']) 102 | args.best_acc = checkpoint['acc'] 103 | args.start_epoch = checkpoint['epoch'] 104 | 105 | optimizer = optim.AdamW(net.parameters(), lr=args.lr) 106 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.n_epochs) 107 | 108 | ##### Training 109 | scaler = torch.cuda.amp.GradScaler(enabled=args.mp_dtype == "float16") 110 | list_loss = [] 111 | list_acc = [] 112 | 113 | if args.wandb: 114 | import wandb 115 | wandb.init(project=exp_args['experiment']+"_vit", 116 | name=run_name) 117 | wandb.config.update(args) 118 | wandb.config.update(exp_args) 119 | 120 | if args.wandb: 121 | wandb.watch(net) 122 | 123 | for epoch in range(args.start_epoch, args.n_epochs): 124 | start = time.time() 125 | trainloss = train(args, epoch, net, net_forward, trainloader, optimizer, scaler, loss_fn=loss_fn) 126 | val_loss, acc = test(args, epoch, net, net_forward, testloader, optimizer, scaler) 127 | 128 | scheduler.step(epoch-1) # step cosine scheduling 129 | 130 | list_loss.append(val_loss) 131 | list_acc.append(acc) 132 | 133 | # Log training.. 134 | if args.wandb: 135 | wandb.log({'epoch': epoch, 'train_loss': trainloss, 'val_loss': val_loss, "val_acc": acc, "lr": optimizer.param_groups[0]["lr"], 136 | "epoch_time": time.time()-start}) 137 | 138 | # Write out csv.. 139 | with open(f'log/log_{args.net}_patch{args.patch}.csv', 'w') as f: 140 | writer = csv.writer(f, lineterminator='\n') 141 | writer.writerow(list_loss) 142 | writer.writerow(list_acc) 143 | print(list_loss) 144 | 145 | # writeout wandb 146 | if args.wandb: 147 | wandb.save("wandb_{}.h5".format(args.net)) 148 | wandb.finish() 149 | 150 | if __name__ == '__main__': 151 | args = default_args 152 | train_model(args, experiment_args) 153 | 154 | --------------------------------------------------------------------------------
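
A minimal sketch of how the harness above is driven (not a file from the repo), assuming the experiment package (mlp_mods here, as in experiment_args above) provides a vit module with the extra_args, get_run_name, and patch_model hooks that train_model calls; everything else comes from the dicts defined above:

from copy import deepcopy

from train_cifar10 import train_model, default_args

args = deepcopy(default_args)
args["n_epochs"] = 20   # shorter smoke-test run
args["wandb"] = False   # skip wandb logging for a quick local check

exp_args = dict(experiment="mlp_mods")  # any experiment folder with a vit.py exposing the hooks above

train_model(args, exp_args)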