├── README.md ├── LICENSE └── mLSTM.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # xlstm 2 | my attempts at implementing various bits of Sepp Hochreiter's new xLSTM architecture 3 | very oversimplified and probably somewhat wrong! 4 | please open PRs and make it better. 5 | 6 | 7 | mLSTM: https://github.com/andrewgcodes/xlstm/blob/main/mLSTM.ipynb 8 | 9 | Open In Colab 10 | 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Andrew Kean Gao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# mLSTM sine-wave demo, originally by @itsandrewgao (https://twitter.com/itsandrewgao).
#
# A deliberately small, simplified implementation of just the mLSTM cell from
# the xLSTM architecture, trained to predict the next sample of a sine wave.

import math

import numpy as np
import torch
import torch.nn as nn


def generate_sine_wave(seq_len, num_sequences):
    """Return one period of a sine wave as a float32 tensor.

    Args:
        seq_len: number of samples taken evenly over [0, 2*pi] (both ends
            included, as `np.linspace` does by default).
        num_sequences: how many identical copies of the wave to place along
            the last dimension.

    Returns:
        Tensor of shape (1, seq_len, num_sequences).
    """
    x = np.linspace(0, 2 * np.pi, seq_len)
    y = np.sin(x)
    return torch.tensor(y).float().view(-1, 1).repeat(1, num_sequences).unsqueeze(0)


class mLSTM(nn.Module):
    """Single, unbatched mLSTM cell with a matrix memory.

    One call to `forward` consumes one time step, given as a column vector of
    shape (input_size, 1), and updates the matrix memory `C` of shape
    (mem_dim, mem_dim) and the normalizer `n` of shape (mem_dim, 1).

    The query and the output gate are sized by `hidden_size` while keys,
    values and the memory are sized by `mem_dim`. The read-out multiplies the
    memory by the query and gates the result element-wise, so the two sizes
    must be equal.

    NOTE(review): the input gate is a plain `exp` with no stabilization, so
    large gate pre-activations can overflow.
    """

    def __init__(self, input_size, hidden_size, mem_dim):
        """Create the projection parameters and initialise them.

        Raises:
            ValueError: if `hidden_size != mem_dim`; `forward` cannot run
                with mismatched sizes, so fail early with a clear message.
        """
        super().__init__()
        if hidden_size != mem_dim:
            raise ValueError(
                f"hidden_size ({hidden_size}) must equal mem_dim ({mem_dim}): "
                "the query reads a (mem_dim, mem_dim) memory"
            )
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.mem_dim = mem_dim

        # Values are filled in by reset_parameters() below.
        self.Wq = nn.Parameter(torch.empty(hidden_size, input_size))
        self.bq = nn.Parameter(torch.empty(hidden_size, 1))
        self.Wk = nn.Parameter(torch.empty(mem_dim, input_size))
        self.bk = nn.Parameter(torch.empty(mem_dim, 1))
        self.Wv = nn.Parameter(torch.empty(mem_dim, input_size))
        self.bv = nn.Parameter(torch.empty(mem_dim, 1))
        self.wi = nn.Parameter(torch.empty(1, input_size))
        self.bi = nn.Parameter(torch.empty(1))
        self.wf = nn.Parameter(torch.empty(1, input_size))
        self.bf = nn.Parameter(torch.empty(1))
        self.Wo = nn.Parameter(torch.empty(hidden_size, input_size))
        self.bo = nn.Parameter(torch.empty(hidden_size, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise weight matrices and zero every bias.

        Biases are recognised by name (leading "b"), not by rank: several of
        them are stored as (n, 1) columns, so a rank test would treat them as
        weight matrices and give them random values.
        """
        for name, p in self.named_parameters():
            leaf = name.rsplit(".", 1)[-1]
            if leaf.startswith("b") or p.dim() < 2:
                nn.init.zeros_(p)
            else:
                nn.init.xavier_uniform_(p)

    def forward(self, x, states):
        """Advance the cell by one time step.

        Args:
            x: input of shape (input_size, 1); a 1-D tensor of shape
                (input_size,) is accepted and treated as a column.
            states: tuple `(C_prev, n_prev)` with shapes
                (mem_dim, mem_dim) and (mem_dim, 1).

        Returns:
            `(ht, (C, n))` where `ht` has shape (hidden_size, 1).

        Raises:
            ValueError: if `x` is not a single (input_size, 1) column.
        """
        (C_prev, n_prev) = states

        # A 1-D x would silently broadcast against the (n, 1) biases and
        # produce (n, n) projections, so normalise to a column up front.
        if x.dim() == 1:
            x = x.unsqueeze(1)
        if tuple(x.shape) != (self.input_size, 1):
            raise ValueError(
                f"expected x of shape ({self.input_size}, 1), got {tuple(x.shape)}"
            )

        qt = torch.matmul(self.Wq, x) + self.bq
        kt = (1 / math.sqrt(self.mem_dim)) * (torch.matmul(self.Wk, x) + self.bk)
        vt = torch.matmul(self.Wv, x) + self.bv

        it = torch.exp(torch.matmul(self.wi, x) + self.bi)
        ft = torch.sigmoid(torch.matmul(self.wf, x) + self.bf)

        # squeeze(1) drops only the column axis, so mem_dim == 1 still yields
        # 1-D vectors for the outer product.
        C = ft * C_prev + it * torch.outer(vt.squeeze(1), kt.squeeze(1))
        n = ft * n_prev + it * kt

        # max(|n^T q|, 1) without allocating a constant tensor on the CPU.
        max_nqt = torch.abs(torch.matmul(n.T, qt)).clamp(min=1.0)
        h_tilde = torch.matmul(C, qt) / max_nqt
        ot = torch.sigmoid(torch.matmul(self.Wo, x) + self.bo)
        ht = ot * h_tilde

        return ht, (C, n)

    def init_hidden(self):
        """Return zeroed `(C, n)` states on the parameters' device and dtype."""
        ref = self.Wq
        return (torch.zeros(self.mem_dim, self.mem_dim, device=ref.device, dtype=ref.dtype),
                torch.zeros(self.mem_dim, 1, device=ref.device, dtype=ref.dtype))


# The demo runs when this cell/script is executed directly (a notebook kernel
# runs cells as "__main__"), but not when the definitions above are imported.
if __name__ == "__main__":
    # Imported here because only the demo needs plotting.
    import matplotlib.pyplot as plt

    SEED = 0
    torch.manual_seed(SEED)  # make the initial weights, and so the run, repeatable

    input_size = 1
    hidden_size = 10
    mem_dim = 10
    seq_len = 100
    num_sequences = 1
    num_epochs = 200

    model = mLSTM(input_size, hidden_size, mem_dim)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    data = generate_sine_wave(seq_len, num_sequences)

    for epoch in range(num_epochs):
        states = model.init_hidden()
        optimizer.zero_grad()
        loss = 0
        for t in range(seq_len - 1):
            x = data[:, t]
            y_true = data[:, t + 1]
            y_pred, states = model(x, states)
            # Every hidden unit is trained towards the same scalar target;
            # expand it explicitly instead of relying on implicit broadcasting.
            loss += criterion(y_pred, y_true.expand_as(y_pred))

        loss.backward()
        optimizer.step()

        if epoch % 10 == 0:
            print(f'Epoch {epoch} Loss {loss.item()}')

    test_output = []
    states = model.init_hidden()
    with torch.no_grad():  # inference only: no autograd graph needed
        for t in range(seq_len - 1):
            x = data[:, t]
            y_pred, states = model(x, states)
            test_output.append(y_pred.numpy().ravel()[0])

    plt.figure(figsize=(10, 4))
    plt.title('Original vs. Predicted Sine Wave')
    plt.plot(data.numpy().ravel(), label='Original')
    # test_output[t] is the prediction for sample t + 1, so start at x = 1.
    plt.plot(range(1, seq_len), test_output, label='Predicted')
    plt.xlabel('Time step')
    plt.ylabel('Value')
    plt.legend()
    plt.show()
"Epoch 130 Loss 2.2156834602355957\n", 174 | "Epoch 140 Loss 1.5426660776138306\n", 175 | "Epoch 150 Loss 1.1080152988433838\n", 176 | "Epoch 160 Loss 0.841367781162262\n", 177 | "Epoch 170 Loss 0.6825742125511169\n", 178 | "Epoch 180 Loss 0.5876235961914062\n", 179 | "Epoch 190 Loss 0.5289116501808167\n" 180 | ] 181 | }, 182 | { 183 | "output_type": "display_data", 184 | "data": { 185 | "text/plain": [ 186 | "
" 187 | ], 188 | "image/png": "\n" 189 | }, 190 | "metadata": {} 191 | } 192 | ] 193 | } 194 | ] 195 | } 196 | --------------------------------------------------------------------------------