├── .gitignore ├── README.md ├── examples ├── 1-seq2seq.ipynb ├── 2-gpt.ipynb └── adder │ └── data.py ├── modulo_math.ipynb ├── src ├── alg_1.py ├── alg_10.py ├── alg_2.py ├── alg_3.py ├── alg_4.py ├── alg_5.py ├── alg_6.py ├── alg_7.py ├── alg_8.py └── alg_9.py └── tests └── test_alg5.py /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Formal Algorithms For Transformers 2 | 3 | PyTorch implementation of transformer algorithms described in "Formal Algorithms for Transformers" by Mary Phuong and Marcus Hutter: https://arxiv.org/abs/2207.09238 4 | 5 | [Algorithm 1](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_1.py): Token embedding 6 | 7 | [Algorithm 2](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_2.py): Positional embedding 8 | 9 | [Algorithm 3](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_3.py): Basic single-query attention 10 | 11 | [Algorithm 4](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_4.py): 𝑽˜ ← Attention(𝑿, 𝒁|W𝒒𝒌𝒗, Mask) 12 | 13 | [Algorithm 5](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_5.py): 𝑽˜ ← MHAttention(𝑿, 𝒁|W, Mask) 14 | 15 | [Algorithm 6](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_6.py): ˆ𝒆 ← layer_norm(𝒆|𝜸, 𝜷) 16 | 17 | [Algorithm 7](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_7.py): Unembedding. 
18 | 19 | [Algorithm 8](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_8.py): 𝑷 ← EDTransformer(𝒛, 𝒙|𝜽) 20 | 21 | [Algorithm 9](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_9.py): 𝑷 ← ETransformer(𝒙|𝜽) 22 | 23 | [Algorithm 10](https://github.com/myazdani/formal-algorithms-for-transformers/blob/main/src/alg_10.py): 𝑷 ← DTransformer(𝒙|𝜽) 24 | 25 | Algorithm 11: 𝜽ˆ ← EDTraining(𝒛1:𝑁data , 𝒙1:𝑁data , 𝜽) 26 | 27 | Algorithm 12: 𝜽ˆ ← ETraining(𝒙1:𝑁data , 𝜽) 28 | 29 | Algorithm 13: 𝜽ˆ ← DTraining(𝒙1:𝑁data , 𝜽) 30 | 31 | Algorithm 14: 𝒚 ← DInference(𝒙, 𝜽ˆ) 32 | -------------------------------------------------------------------------------- /examples/1-seq2seq.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "from torch.nn import functional as F\n", 12 | "\n", 13 | "import sys\n", 14 | "sys.path.insert(1, '../src')\n", 15 | "from alg_8 import EDTransformer" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 2, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "max_seq_len = 8\n", 25 | "embed_dim = 50\n", 26 | "vocab_size = 12\n", 27 | "\n", 28 | "bs = 32\n", 29 | "z_ids = torch.randint(0,vocab_size, size = (bs*2, max_seq_len)) \n", 30 | "x_ids = (z_ids+5)%vocab_size" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "ed_seq2seq = EDTransformer(embed_dim=embed_dim, mlp_dim=32, max_seq_len=max_seq_len,\n", 40 | " L_dec=3, L_enc=3, vocab_size=vocab_size, num_heads=3)\n", 41 | "\n", 42 | "\n", 43 | "neg_ll_loss = nn.NLLLoss()\n", 44 | "optimizer = torch.optim.SGD(ed_seq2seq.parameters(), lr=0.10, momentum=0.9)\n", 45 | "losses = []" 46 | ] 47 | }, 48 | { 49 | 
"cell_type": "code", 50 | "execution_count": 4, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "%%time\n", 55 | "for _ in range(1000):\n", 56 | " optimizer.zero_grad()\n", 57 | " output = ed_seq2seq(z_ids, x_ids)\n", 58 | " loss = neg_ll_loss(torch.log(output.view(-1,vocab_size,max_seq_len)), x_ids)\n", 59 | " loss.backward()\n", 60 | " optimizer.step()\n", 61 | " losses.append(loss.item())\n" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "import matplotlib.pyplot as plt" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [ 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "[]" 82 | ] 83 | }, 84 | "execution_count": 51, 85 | "metadata": {}, 86 | "output_type": "execute_result" 87 | }, 88 | { 89 | "data": { 90 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD4CAYAAADiry33AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAin0lEQVR4nO3de3Rc5Xnv8e8zo9FdlixbNrZsYy4GYgI2jmpuSQyEUJOWODmHZJmGENIkLgnpSZPeSHOatnDa04Q2q00hUA4xNCmBlRRI3NTcQkgoV1sm5mYwCBuwsLFkyxddrMtonvPHbMljWbLG1khbs+f3WWvW7P2+e0bPK1s/be3Z+93m7oiISHTFwi5ARETGl4JeRCTiFPQiIhGnoBcRiTgFvYhIxBWFXcBwpk+f7vPnzw+7DBGRvLFhw4Zd7l43XN+kDPr58+fT2NgYdhkiInnDzN4aqU+HbkREIk5BLyIScQp6EZGIU9CLiEScgl5EJOIU9CIiEaegFxGJuEl5Hv2x+u6jr5PsT2FmmIEx8Ez62QwGlof2BevpfstoJ+P90o1DXzOwHjOjsqSIKWUJqssSnFRXSVlxfNha3Z3e/hTF8dhgXSIi4yFSQX/rr9+gq7c/7DIGxWPGh06bwUcXz+aff/E6ALdcuYSYGVetXkfzngPMqy3nrz+6kItOmxlytSISVTYZbzzS0NDgY7ky1t1xBx9YhmA93c6Q9aHbcYQ+T3ce9p4OpFJOe3eS9u4+dnf2snHbXu5+9m3ae5KDtX1gwXRqyotZ++IOrr3gJB7etJO3dnfxxJ9fyLTKkmMes4gUNjPb4O4Nw/VFao9+wMChlmAttDo+csYsPnX2PK7/z02cML2CZMq586k3AZhanuBrl5zKhxcex2U3PcG3HnyVb1++KLRaRSS6Ihn0k8nx0yr4/tW/BcC+rj46e5L8ZEMze7r6AHjPrCoAftzYzA0fey/7DvQxo6o0tHpFJHp01s0Eqi5P8O3LzwRgxeLZABTFY3z14lMAOPV/P8jSv32U/3X3b0KrUUSiR0E/wcyMV29Yznc+uXiw7UsXnnTINmue307L/m62tXVNcHUiEkWj
Br2ZzTWzx8zsFTN72cy+Msw2ZmbfNbMmM3vBzJZk9C03s81B33W5HkA+Kk3EiccOfnaQiMe48fIzKc84FXPp3z3KB779GG2dvWGUKCIRks0efRL4Y3d/D3AOcK2ZLRyyzaXAguCxCrgFwMziwM1B/0LgimFeK8AnGuay6frlfO79JxzSvv7NNgBufqyJW371RhiliUieGzXo3X2Huz8XLLcDrwD1QzZbAfzA054BasxsFrAUaHL3Le7eC9wTbCsjWDS35pD1ppYOAG58aDPfevBVkv2pEKoSkXx2VMfozWw+cBbw7JCuemBbxnpz0DZS+3DvvcrMGs2ssbW19WjKipTLzpzFDz+3dHD9P5/fzp6Mwzedk+iCMBHJD1kHvZlVAvcCf+Tu+4d2D/MSP0L74Y3ut7l7g7s31NUNe9vDgmBmfGDBwfG/+m47n71z/eB6R8bFVyIi2cgq6M0sQTrk73L3+4bZpBmYm7E+B9h+hHYZxcAplwAbt+0dXO7oVtCLyNHJ5qwbA74PvOLu3xlhszXAVcHZN+cA+9x9B7AeWGBmJ5hZMbAy2FZG8ZWLF7Dhf18MwMJZUwbbO3r6BpfbOnvp7tOhHBE5smz26M8HPg1cZGYbg8dHzOwaM7sm2GYtsAVoAv4f8CUAd08CXwYeIv0h7o/d/eVcDyKqplWWcOU589i04+CRsv0HDu7RL7nhET57x/rhXioiMmjUKRDc/QlGmTDG0zOjXTtC31rSvwjkGPzJJafys43baQ8O2bS0dwPwyKadADy9ZXdotYlIftCVsZNcTXkx933xvMGLqR5+eSfPbNnNF35w7LN7ikhhUdDngQUzq9h0/XKOm1LKo6+28JV7NBeOiGRPQZ9HfnLNuQDs3N8TciUikk8U9Hlkbm05N6w4PewyRCTPKOjzzKfPnc+fLT/1kLauXp1bLyIjU9DnoS9dcPIh6x+7+cmQKhGRfKCgz1O3X3Xw1pCv7exg3da2EKsRkclMQZ+nLl44k3Xf+NDg+q2/1hTGIjI8BX0em1FVyl2fPxuAX77aor16ERmWgj7PnX/ydL57xVkAXHn7s5qvXkQOo6CPgI8uSt9ovLc/xV/+7KWQqxGRyUZBHxHf+eQiAO5et42mlvaQqxGRyURBHxEfP6uec06sBeDi7zwecjUiMpko6CPCzLjx8kWD67s6NE2CiKQp6CNkbm05V517PAAN/+cXIVcjIpOFgj5irjp3/uDy46+18tDL74ZXjIhMCgr6iDl5RiUXnTYDgKtWr+MPfrgh5IpEJGzZ3DN2tZm1mNmw5+2Z2Z9m3GLwJTPrN7PaoO9NM3sx6NOdMibIly44KewSRGQSyWaP/k5g+Uid7n6juy9298XA14Ffu3vmJZoXBv0Nw7+D5Nr7jp8adgkiMomMGvTu/jiQ7bX1VwB3j6kiGTMz446rf2tw/bWdOq9epJDl7Bi9mZWT3vO/N6PZgYfNbIOZrRrl9avMrNHMGltbW3NVVsG68LQZfLJhDgA3/HxTyNWISJhy+WHsZcCTQw7bnO/uS4BLgWvN7IMjvdjdb3P3BndvqKury2FZhetrH07foKQoZiFXIiJhymXQr2TIYRt33x48twD3A0tz+PVkFMdVl7Ji8Wwa39pDnyY7EylYOQl6M6sGlgE/y2irMLOqgWXgEkAzbk2wi06bQXt3krueeSvsUkQkJNmcXnk38DRwqpk1m9nnzOwaM7smY7OPAw+7e2dG20zgCTN7HlgH/Je7P5jL4mV0KxbXc9a8Gn6ooBcpWEWjbeDuV2SxzZ2kT8PMbNsCLBpue5lYl773OP5u7au0dfZSW1EcdjkiMsF0ZWwBOKO+BoB1W3eHW4iIhEJBXwDeWz8FgGv+/bmQKxGRMCjoC0BVaYKq0lGP0olIRCnoC8TV580nHjNSKQ+7FBGZYAr6AjG1vJj+lLOzvTvsUkRkginoC8Tx08oBWLNxe8iViMhEU9AXiIE56nd39oZciYhMNAV9gTBLz3dz
2+Nb6NdxepGCoqAvQC+9sy/sEkRkAinoC8g/r1wMwNNbdOGUSCFR0BeQFYvrWTCjksdf03z/IoVEQV9gLl44k2e3trGvqy/sUkRkgijoC8yyU+roTznPbdsTdikiMkEU9AWmvqYMgNb2npArEZGJoqAvMNMrSwAFvUghUdAXmLLiOBXFcQW9SAFR0BegubXlvLW7c/QNRSQSFPQF6KS6SrbsUtCLFIps7hm72sxazGzYG3ub2QVmts/MNgaPb2b0LTezzWbWZGbX5bJwOXYn1lXw1u4ufvO2zrwRKQTZ7NHfCSwfZZv/dvfFweN6ADOLAzcDlwILgSvMbOFYipXcmDmlFID/ectTIVciIhNh1KB398eBtmN476VAk7tvcfde4B5gxTG8j+TY5e+bQzxm1FWVhF2KiEyAXB2jP9fMnjezB8zs9KCtHtiWsU1z0DYsM1tlZo1m1tjaqkv0x1NpIs7vnz+fvV19uGsmS5Goy0XQPwcc7+6LgH8Bfhq02zDbjpgq7n6buze4e0NdXV0OypIjqa0ooSeZoqu3P+xSRGScjTno3X2/u3cEy2uBhJlNJ70HPzdj0zmAbm80ScyqTh+nf2fvgZArEZHxNuagN7PjLLirhZktDd5zN7AeWGBmJ5hZMbASWDPWrye5cepxVQC8smN/yJWIyHgrGm0DM7sbuACYbmbNwF8BCQB3vxW4HPiimSWBA8BKTx/4TZrZl4GHgDiw2t1fHpdRyFE7qa6Sopjx6rvt+oRcJOJGDXp3v2KU/puAm0boWwusPbbSZDwVF8U4qa6S13e2h12KiIwzXRlbwGZWl2rOG5ECoKAvYNMritnV0Rt2GSIyzhT0BWxaZTG7O3tIpXQuvUiUKegL2BlzaujuS3H7E1vCLkVExpGCvoBdduYsTpxewW/e3ht2KSIyjhT0BczMmFZZzF7dKFwk0hT0Ba66rJg9XfpAViTKFPQFbmp5gn0HtEcvEmUK+gI3tUJ79CJRp6AvcNVlCbr7UnT3aRZLkahS0Be4qeXFAPpAViTCFPQFrqY8AaDDNyIRpqAvcNMq0nv0uzo0541IVCnoC9yc2nIAmvfoBiQiUaWgL3Azq0ooihnNe7rCLkVExomCvsAVxWPMqillW5v26EWiSkEvzKkpZ5v26EUia9SgN7PVZtZiZi+N0P8pM3sheDxlZosy+t40sxfNbKOZNeaycMmdhbOn8PI7+9mnUyxFIimbPfo7geVH6N8KLHP3M4EbgNuG9F/o7ovdveHYSpTx9vGz6untT/FfL+4IuxQRGQejBr27Pw60HaH/KXffE6w+A8zJUW0yQU6fPYW5tWU8trkl7FJEZBzk+hj954AHMtYdeNjMNpjZqiO90MxWmVmjmTW2trbmuCw5EjPj1JlTdIqlSETlLOjN7ELSQf/nGc3nu/sS4FLgWjP74Eivd/fb3L3B3Rvq6upyVZZkqbYiQVunLpoSiaKcBL2ZnQncDqxw990D7e6+PXhuAe4Hlubi60nu1VaU0NbZi7vuHysSNWMOejObB9wHfNrdX8torzCzqoFl4BJg2DN3JHzTKorp63fae5JhlyIiOVY02gZmdjdwATDdzJqBvwISAO5+K/BNYBrwPTMDSAZn2MwE7g/aioAfufuD4zAGyYHaYM6bto5eppQmQq5GRHJp1KB39ytG6f888Plh2rcAiw5/hUxGtZXpoN/d2cv86RUhVyMiuaQrYwWA2mBe+rZOTVcsEjUKegEOHrrZo6AXiRwFvQAwLePQjYhEi4JeACgvTn9c860HXw25EhHJNQW9HGZbm2ayFIkSBb0M+h9L6gH4zB3r6E/pwimRqFDQy6B//ET6bNgtrZ08+srOkKsRkVxR0Mug4OI2AIridoQtRSSfKOhlWKlU2BWISK4o6OUQd3/hHAA6ezXnjUhUKOjlEMdPKwfgQG9/yJWISK4o6OUQFcH59J0KepHIUNDLIcpL4gB0abpikchQ0MshEvEYU8sT7NjfHXYp
IpIjCno5zLxpFby1uzPsMkQkRxT0cpg5U8t4smk3P27cFnYpIpIDCno5TF1lCQB/9h8vhFyJiOSCgl4OU1dVMrjcm9SVUyL5btSgN7PVZtZiZsPe2NvSvmtmTWb2gpktyehbbmabg77rclm4jJ8ZGUHfqbNvRPJeNnv0dwLLj9B/KbAgeKwCbgEwszhwc9C/ELjCzBaOpViZGKfPrh5c7lDQi+S9UYPe3R8H2o6wyQrgB572DFBjZrOApUCTu29x917gnmBbmeQWzp7ClefMA+Azq9eFXI2IjFUujtHXA5mnZzQHbSO1D8vMVplZo5k1tra25qAsGYtLFh4HwJZdnST7dZxeJJ/lIuiHm8/Wj9A+LHe/zd0b3L2hrq4uB2XJWHRlTIHw69f0i1cknxXl4D2agbkZ63OA7UDxCO2SBz6wYPrgsutmUyJ5LRd79GuAq4Kzb84B9rn7DmA9sMDMTjCzYmBlsK3kgYqSIn7xtWWApiwWyXej7tGb2d3ABcB0M2sG/gpIALj7rcBa4CNAE9AFfDboS5rZl4GHgDiw2t1fHocxyDipLEn/9/jKPRv50HtmDq6LSH4Z9SfX3a8Ypd+Ba0foW0v6F4HkoYGZLAE2vr2X92cczhGR/KErY2VE5YmDQZ/QPWRF8paCXkZUFD/430PH6UXyl4JejugfPrEIgI4e3XFKJF8p6OWI3n9y+rh8R7f26EXylYJejqiyNP15/f7uPvpTOqFeJB8p6OWIKkuKmFZRzN8/8Cpn/92jYZcjIsdAQS+jWjS3BoBdHT2an14kDynoZVQfP+vgXHQHevWhrEi+UdDLqJadenCSOZ1mKZJ/FPQyqimlCf7lirMA6FLQi+QdBb1kpSKYDmH1k2+GW4iIHDUFvWSlLJE+zfJHz74dciUicrQU9JKVedPKB5dTOp9eJK8o6CUr9TVlLD89fXvBdt0wXCSvKOglaxcvnAnAxm17wy1ERI6Kgl6yNq2iGIDPrF7HtraukKsRkWwp6CVrZ82rGVxu6+wNrxAROSpZBb2ZLTezzWbWZGbXDdP/p2a2MXi8ZGb9ZlYb9L1pZi8GfY25HoBMnJry4sFlXTglkj+yuWdsHLgZ+DDQDKw3szXuvmlgG3e/Ebgx2P4y4Kvu3pbxNhe6+66cVi6h2n+gL+wSRCRL2ezRLwWa3H2Lu/cC9wArjrD9FcDduShOJp/H/uQCAPZ0KehF8kU2QV8PbMtYbw7aDmNm5cBy4N6MZgceNrMNZrZqpC9iZqvMrNHMGltbW7MoS8IwZ2oZZYk4m99tD7sUEclSNkE/3F2hR7pi5jLgySGHbc539yXApcC1ZvbB4V7o7re5e4O7N9TV1Q23iUwCiXiM02dPYdOO/WGXIiJZyibom4G5GetzgO0jbLuSIYdt3H178NwC3E/6UJDksZlTSmlt7wm7DBHJUjZBvx5YYGYnmFkx6TBfM3QjM6sGlgE/y2irMLOqgWXgEuClXBQu4dnV0cPWXZ08/cbusEsRkSyMGvTungS+DDwEvAL82N1fNrNrzOyajE0/Djzs7p0ZbTOBJ8zseWAd8F/u/mDuypcwzJxSCsCzWxX0Ivlg1NMrAdx9LbB2SNutQ9bvBO4c0rYFWDSmCmXSueFj72XN89spLtL1diL5QD+pctSqyxKUFMXYq1MsRfKCgl6OydTyYv79mbfY0toRdikiMgoFvRyTmvIEXb39XPSPvw67FBEZhYJejknMhru8QkQmIwW9HJOU6y5TIvlCQS/HpK6qJOwSRCRLCno5Jn908SmDy79+TXMTiUxmCno5Ju87fipXnzcfgD/4oW4zIDKZKejlmPWn0sfpK4qzuu5OREKioJdj9tUPpw/fZN5icCSb321n/Ztto24nIrmnoJdjVhvcLPwXr7Swfe+BI2772//0OJ+49emJKEtEhlDQS06c9/e/xHXKpcikpKCXMfnXT79vcLknmQqxEhEZiYJexmTx3JrB5a7e/vAKEZERKehlTGZO
KeXaC08CoLMnGXI1IjIcBb2M2emzqwHt0YtMVgp6GbPy4jgAnb3aoxeZjBT0MmblwQVTNz64OeRKRGQ4WQW9mS03s81m1mRm1w3Tf4GZ7TOzjcHjm9m+VvLfqTOrAHh6y24e2bQz5GpEZKhRg97M4sDNwKXAQuAKM1s4zKb/7e6Lg8f1R/layWPV5Qk+vHAmAF/4gea9EZlsstmjXwo0ufsWd+8F7gFWZPn+Y3mt5JE5U8sGl7e1dYVYiYgMlU3Q1wPbMtabg7ahzjWz583sATM7/Shfi5mtMrNGM2tsbdW0t/mmtrx4cPnhIxy+SaV09azIRMsm6Ie7Z9zQn9bngOPdfRHwL8BPj+K16Ub329y9wd0b6urqsihLJpPpGTciad4z8h59b7+unhWZaNkEfTMwN2N9DrA9cwN33+/uHcHyWiBhZtOzea1Ewycb5vJvv7+UBTMqeWfPyBOcfevBVyewKhGB7IJ+PbDAzE4ws2JgJbAmcwMzO84sfbdoM1savO/ubF4r0RCPGctOqaN+ahnb940c9Pf/5p0JrEpEAEa9Y4S7J83sy8BDQBxY7e4vm9k1Qf+twOXAF80sCRwAVnp6KsNhXztOY5FJoL6mjBea9x3SljmrZX//weW9Xb3EY0ZVaWLC6hMpRFndGig4HLN2SNutGcs3ATdl+1qJrtk1ZbR19tLVmxy8kCpzVsv2jPlwFl//CHVVJaz/xsUTXqdIIdGVsZJT9TXp0yy37+0ebOvuO3QOnD2dvYN7+a3tPRNXnEiBUtBLTtUH59Nfefuzg21D56nf3dnL7s7eCa1LpJDprs6SU6cdl54O4d393bR391FVmjhsj353Rw823Im3IjIutEcvOVVVmuB7n1oCwIqbn6SzJ0l3X3qP/g8vOhlI79Hr6lmRiaOgl5w776RpAGxp7eRXm1sH9+gHjt/v7ujh6jvWD25/97q3J75IkQKioJecqykv5umvX0RxPMYjm94dDPpZQdDf9FjTIdt//b4XJ7xGkUKioJdxMau6jEtOn8l/vrCDd/enz8CpCG5QsnP/4WfaHNDdqUTGjYJexs3vnDGL/pTz+Gu7AJheWXLYNn/z0fT8dy3t3Yf1iUhuKOhl3MyYUgrAvc81AzCrppR7v3guABe/ZwZv/v3vsGBGJQBbdnWGU6RIAdDplTJuZteUDi4vmFFJSVGcJfOm8s8rF3PW3KkAnDVvKsVFMZ54fRcXnjojrFJFIk1BL+NmVnUZv/jaMnZ39DBvWjkAZsaKxQdvSVBWHOecE6fxq80t/OXv6uZjIuNBh25kXJ08o5KzT5zGrOqyEbe54JQ63mjtZMNbbRNYmUjhUNBL6D7RMIfy4jj/9IvXdQcqkXGgoJfQVZUmOKO+mv9+fRffeeS1sMsRiRwFvUwK3778TCB9MdW6rW10ZExnLCJjo6CXSeH4aRWUBxdUffJfn+YvdLWsSM4o6GXSuOPq3xpcXvP8ds77v4+S1M3ERcYsq6A3s+VmttnMmszsumH6P2VmLwSPp8xsUUbfm2b2opltNLPGXBYv0XL2ienJ0GZUpa+g3b6vmz+/V3v2ImM16nn0ZhYHbgY+DDQD681sjbtvythsK7DM3feY2aXAbcDZGf0XuvuuHNYtEdX0t5diZtzw803c+dSb3PtcM2+3dfKjL5xDIq4/QEWORTY/OUuBJnff4u69wD3AiswN3P0pd98TrD4DzMltmVIoiuIx4jHjrz96+uBNTNa/uYcF33iAxza3hFydSH7KJujrgW0Z681B20g+BzyQse7Aw2a2wcxWjfQiM1tlZo1m1tja2ppFWRJ1P7nmXH75x8s4ffYUAD57x3oW/c3DrLztaX76m3doCWbF3N/dx1/+9CWeatIfjSLDsYGbNI+4gdkngN92988H658Glrr7Hw6z7YXA94D3u/vuoG22u283sxnAI8AfuvvjR/qaDQ0N3tiow/mS5u5c+6PnWPviu8P2F8dj9AYf2n7q7HmcUV/Nx86q
pzQRn8gyRUJlZhvcvWG4vmzmumkG5maszwG2D/NFzgRuBy4dCHkAd98ePLeY2f2kDwUdMehFMpkZN//eEpr3HOC46lJa23v4SWMzdzy1lb1dffT2p5hRVUJtRTH3PfcOdz37Nn9x/4ucMrMKdzh5ZiXLTqljS2sn699sY9kpdcyqLuXJpl0kU07D8VOZXlXCrOpSasqLKU3E6U2mDj76++no6eeFbXs59bgqjp9WQf3UMsoT8cF731rGTXDdnb5+pyfZT0VxEbGYsberl7fbujhuSikliTiliRglRfHB7VMOyVSKVCpz3IcvG5axfPBrW8Z2NswNeQd26IbrA0j2p+hJph+t7T1s3dXJ6zvbmVldyv4DfTz6SguXnD6TW371Br39KS47czbvrZ/CvNoK5kwtY0ppgoqSOEWT4HMUd6e1I33Pg7rKkhHHPB5fd9+BPh5/fRdzppZx0vRK7lr3FvsPJKkuS1AUM7bvO8DiuTWcOaeGqeUJSorilBTFiMXGt8Zs9uiLgNeADwHvAOuB33P3lzO2mQf8ErjK3Z/KaK8AYu7eHiw/Alzv7g8e6Wtqj16y1d7dR3dfirrgTJ29Xb3c9ezbbNqxn9b9PVSVFrFuaxvtPUniMeP42vKcT4lsBol4jKKYkUw5yf4UAzM5JOJGyqF/mKkdioIf7uQETfsQjxnlxXHKi+O4Q0/wi6wn2U+uSiiKGcVF6e9FUTxGyh13SLlD8Oykf0mVJOLEY4a7059KP1KerrO4KEZxPBb84kq/98AvOSN9PDjlPviLMeUePKCjO8mB4K5mJUUxKkqK0u+dcvrdScRjJOJGzIx47OBj4H3TpabrzoxHD2o/pD/o6+tP0d6dHPbfORvF8RgliRizqkt5+KvLjuk9xrRH7+5JM/sy8BAQB1a7+8tmdk3QfyvwTWAa8L3gt2cy+IIzgfuDtiLgR6OFvMjRqCpNUHVwNmRqyou59sKTD9mmu6+fnfu7qSkvproswZbWDrr7UsybVk5nT5Kd+7vp6u2nJ5lib1cv3X39QdDE089B6JxYV8G2ti527Otm254uepOpwRDr608HfDxuJGIxShPp1+3p6sOAqeXFzK0tY/veblLu9CRTdPQkMdLhGI/FKArCx+xgmMDhYZPZ5ocsH+zzgy+AYI8/mUrR1dtPV08/sdhAuMTTz8E4S4pilBXHmT+tgtOOm8L2fQeIx4x5teU0tXRwUl0lRXEjbsbW3Z1sa+uitb2H/d1JOrqTdCf76UumSAbBbQaxIKljZgzsuKac4BeMDwZuut/oT6Xo7U/Rm/SD4+XQoI0Ff7kMvH8seDYzyhJx5k9Pz5a6ra2L7r5UOsgN4mb09afoGwj+gUfGN9nI+CvJDv8ryjLaBtoT8RhTShNUlRZx+uxqWju6adnfw7zacs6YU01PMsWu9h5qK4rp6EnS1NJBR0+SnmSK7r7+weeBv/JybdQ9+jBoj15E5OgcaY8+/ANqIiIyrhT0IiIRp6AXEYk4Bb2ISMQp6EVEIk5BLyIScQp6EZGIU9CLiETcpLxgysxagbeO8eXTgUKbxlBjLgwac/SNZbzHu3vdcB2TMujHwswaR7o6LKo05sKgMUffeI1Xh25ERCJOQS8iEnFRDPrbwi4gBBpzYdCYo29cxhu5Y/QiInKoKO7Ri4hIBgW9iEjERSbozWy5mW02syYzuy7senLFzOaa2WNm9oqZvWxmXwnaa83sETN7PXiemvGarwffh81m9tvhVT82ZhY3s9+Y2c+D9UiP2cxqzOw/zOzV4N/73AIY81eD/9cvmdndZlYatTGb2WozazGzlzLajnqMZvY+M3sx6PuuHc3NcN097x+kb3H4BnAiUAw8DywMu64cjW0WsCRYriJ9/96FwLeB64L264BvBcsLg/GXACcE35d42OM4xrF/DfgR8PNgPdJjBv4N+HywXAzURHnMQD2wFSgL1n8MXB21MQMfBJYAL2W0HfUYgXXAuaTvaPgAcGm2NURl
j34p0OTuW9y9F7gHWBFyTTnh7jvc/blguR14hfQPyArSwUDw/LFgeQVwj7v3uPtWoIn09yevmNkc4HeA2zOaIztmM5tCOhC+D+Duve6+lwiPOVAElJlZEVAObCdiY3b3x4G2Ic1HNUYzmwVMcfenPZ36P8h4zaiiEvT1wLaM9eagLVLMbD5wFvAsMNPdd0D6lwEwI9gsKt+LfwL+DEhltEV5zCcCrcAdweGq282sggiP2d3fAf4BeBvYAexz94eJ8JgzHO0Y64Ploe1ZiUrQD3esKlLnjZpZJXAv8Efuvv9Imw7TllffCzP7XaDF3Tdk+5Jh2vJqzKT3bJcAt7j7WUAn6T/pR5L3Yw6OS68gfYhiNlBhZlce6SXDtOXVmLMw0hjHNPaoBH0zMDdjfQ7pPwEjwcwSpEP+Lne/L2jeGfw5R/DcErRH4XtxPvBRM3uT9GG4i8zs34n2mJuBZnd/Nlj/D9LBH+UxXwxsdfdWd+8D7gPOI9pjHnC0Y2wOloe2ZyUqQb8eWGBmJ5hZMbASWBNyTTkRfLL+feAVd/9ORtca4DPB8meAn2W0rzSzEjM7AVhA+kOcvOHuX3f3Oe4+n/S/5S/d/UqiPeZ3gW1mdmrQ9CFgExEeM+lDNueYWXnw//xDpD+DivKYBxzVGIPDO+1mdk7wvboq4zWjC/sT6Rx+sv0R0mekvAF8I+x6cjiu95P+E+0FYGPw+AgwDXgUeD14rs14zTeC78NmjuKT+cn4AC7g4Fk3kR4zsBhoDP6tfwpMLYAx/w3wKvAS8EPSZ5tEaszA3aQ/g+gjvWf+uWMZI9AQfJ/eAG4imNkgm4emQBARibioHLoREZERKOhFRCJOQS8iEnEKehGRiFPQi4hEnIJeRCTiFPQiIhH3/wGMFPmYzJqRCQAAAABJRU5ErkJggg==", 91 | "text/plain": [ 92 | "
"""
Dataset for training an adder for a GPT-style architecture.
Shamelessly taken from: https://github.com/karpathy/minGPT
"""

import os
import sys
import json

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader


class AdditionDataset(Dataset):
    """
    Creates n-digit addition problems. For example, if n=2, then an example
    addition problem would be to add 85 + 50 = 135. This problem would be
    represented as the following string for the GPT:

    "8550531"

    This is because:
    - we are discarding the + and =, which are not necessary. We just encode the digits
      of the input numbers concatenated together.
    - the result 135 is encoded backwards to make the addition easier to learn for the
      GPT model, because of how the addition algorithm works.

    As one more example, the problem 6 + 39 = 45 would be encoded as:

    "0639054"

    where you will notice that we are padding with zeros to make sure that we always
    produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7.
    At test time, we will feed in an addition problem by giving the first 2n digits,
    and hoping that the GPT model completes the sequence with the next (n+1) digits
    correctly.
    """

    def __init__(self, ndigit, split, seed=1337):
        """Build the train or test partition of all ndigit addition problems.

        Args:
            ndigit: number of digits per operand (must be <= 3, otherwise the
                dense enumeration below is too memory hungry).
            split: either 'train' or 'test'; selects which side of the random
                permutation this dataset exposes.
            seed: seed for the permutation RNG. Fixing it (default 1337, as in
                minGPT) makes the split reproducible. The previous code used an
                unseeded torch.Generator(), so every instantiation produced a
                different permutation — a 'train' and a 'test' dataset created
                in separate runs could overlap (train/test leakage).
        """
        self.split = split  # train/test

        # split up all addition problems into either training data or test data
        self.ndigit = ndigit
        assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support"
        num = (10**ndigit)**2  # total number of possible addition problems with ndigit numbers
        rng = torch.Generator()
        rng.manual_seed(seed)  # deterministic split: same permutation every run
        perm = torch.randperm(num, generator=rng)
        num_test = min(int(num*0.2), 500)  # 20% of the whole dataset, or only up to 500
        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]

    def get_vocab_size(self):
        """Return the token vocabulary size: the 10 digits 0..9."""
        return 10

    def get_block_size(self):
        """Return the model context length needed for one problem.

        a, b, a+b, and +1 due to potential carry overflow,
        but then also -1 because very last digit doesn't ever plug back
        as there is no explicit token to predict, it is implied.
        """
        return 3*self.ndigit + 1 - 1

    def __len__(self):
        """Number of problems in this split."""
        return self.ixes.nelement()

    def __getitem__(self, idx):
        """Return one (x, y) next-token-prediction pair as LongTensors.

        x is the token sequence fed to the GPT; y is x shifted left by one,
        with the positions covering the two input operands masked to -1 so
        the loss is only computed on the answer digits.
        """
        ndigit = self.ndigit
        # given a problem index idx, first recover the associated a + b
        idx = self.ixes[idx].item()
        nd = 10**ndigit
        a = idx // nd
        b = idx % nd
        # calculate the "label" of the addition problem a + b
        c = a + b
        # encode the digits of a, b, c into strings (zero-padded to fixed width)
        astr = f'%0{ndigit}d' % a
        bstr = f'%0{ndigit}d' % b
        cstr = (f'%0{ndigit+1}d' % c)[::-1]  # reverse c to make addition easier
        render = astr + bstr + cstr
        dix = [int(s) for s in render]  # convert each character to its token index
        # x will be input to GPT and y will be the associated expected outputs
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)  # predict the next token in the sequence
        y[:ndigit*2-1] = -1  # we will only train in the output locations. -1 will mask loss to zero
        return x, y
ECC |\n", 33 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 34 | "| | | MIG M. |\n", 35 | "|===============================+======================+======================|\n", 36 | "| 0 NVIDIA A100-SXM... Off | 00000000:00:04.0 Off | 0 |\n", 37 | "| N/A 30C P0 45W / 400W | 0MiB / 40960MiB | 0% Default |\n", 38 | "| | | Disabled |\n", 39 | "+-------------------------------+----------------------+----------------------+\n", 40 | " \n", 41 | "+-----------------------------------------------------------------------------+\n", 42 | "| Processes: |\n", 43 | "| GPU GI CI PID Type Process name GPU Memory |\n", 44 | "| ID ID Usage |\n", 45 | "|=============================================================================|\n", 46 | "| No running processes found |\n", 47 | "+-----------------------------------------------------------------------------+\n" 48 | ] 49 | } 50 | ], 51 | "source": [ 52 | "!nvidia-smi" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 2, 58 | "metadata": { 59 | "colab": { 60 | "base_uri": "https://localhost:8080/" 61 | }, 62 | "id": "2aio8SkZRCz8", 63 | "outputId": "af58f121-db78-4a3f-d012-fe2396a20479" 64 | }, 65 | "outputs": [ 66 | { 67 | "output_type": "stream", 68 | "name": "stdout", 69 | "text": [ 70 | "fatal: destination path 'formal-algorithms-for-transformers' already exists and is not an empty directory.\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "!git clone https://github.com/myazdani/formal-algorithms-for-transformers.git" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 3, 81 | "metadata": { 82 | "colab": { 83 | "base_uri": "https://localhost:8080/" 84 | }, 85 | "id": "oBNiX3vSR1Ss", 86 | "outputId": "4261a54b-9369-40d9-fba9-e259212f2d4c" 87 | }, 88 | "outputs": [ 89 | { 90 | "output_type": "stream", 91 | "name": "stdout", 92 | "text": [ 93 | "--2023-05-03 05:50:01-- https://raw.githubusercontent.com/myazdani/regularized-t-learner/main/requirements.txt\n", 94 | 
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...\n", 95 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.\n", 96 | "HTTP request sent, awaiting response... 200 OK\n", 97 | "Length: 107 [text/plain]\n", 98 | "Saving to: ‘reqs.txt’\n", 99 | "\n", 100 | "reqs.txt 100%[===================>] 107 --.-KB/s in 0s \n", 101 | "\n", 102 | "2023-05-03 05:50:01 (7.32 MB/s) - ‘reqs.txt’ saved [107/107]\n", 103 | "\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "#!pip install -r ./formal-algorithms-for-transformers/requirements.txt\n", 109 | "!wget -O reqs.txt https://raw.githubusercontent.com/myazdani/regularized-t-learner/main/requirements.txt" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 4, 115 | "metadata": { 116 | "id": "Ppkzb4A8Tjzv" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "%%capture\n", 121 | "!pip install -r reqs.txt" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "source": [ 127 | "!pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu118" 128 | ], 129 | "metadata": { 130 | "colab": { 131 | "base_uri": "https://localhost:8080/" 132 | }, 133 | "id": "mI9Z-MhFJYY3", 134 | "outputId": "e01780a9-b430-4e32-b11e-984ac75199e3" 135 | }, 136 | "execution_count": 5, 137 | "outputs": [ 138 | { 139 | "output_type": "stream", 140 | "name": "stdout", 141 | "text": [ 142 | "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/, https://download.pytorch.org/whl/nightly/cu118\n", 143 | "Collecting numpy\n", 144 | " Using cached numpy-1.24.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.3 MB)\n", 145 | "Collecting torch[dynamo]\n", 146 | " Using cached https://download.pytorch.org/whl/nightly/cu118/torch-2.1.0.dev20230501%2Bcu118-cp310-cp310-linux_x86_64.whl (2268.2 MB)\n", 
147 | "Collecting pytorch-triton==2.1.0+7d1a95b046\n", 148 | " Using cached https://download.pytorch.org/whl/nightly/pytorch_triton-2.1.0%2B7d1a95b046-cp310-cp310-linux_x86_64.whl (88.8 MB)\n", 149 | "Collecting typing-extensions\n", 150 | " Using cached typing_extensions-4.5.0-py3-none-any.whl (27 kB)\n", 151 | "Collecting filelock\n", 152 | " Using cached filelock-3.12.0-py3-none-any.whl (10 kB)\n", 153 | "Collecting fsspec\n", 154 | " Using cached https://download.pytorch.org/whl/nightly/fsspec-2023.4.0-py3-none-any.whl (153 kB)\n", 155 | "Collecting jinja2\n", 156 | " Using cached https://download.pytorch.org/whl/nightly/Jinja2-3.1.2-py3-none-any.whl (133 kB)\n", 157 | "Collecting sympy\n", 158 | " Using cached sympy-1.12rc1-py3-none-any.whl (5.7 MB)\n", 159 | "Collecting networkx\n", 160 | " Using cached networkx-3.1-py3-none-any.whl (2.1 MB)\n", 161 | "Collecting MarkupSafe>=2.0\n", 162 | " Using cached https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (25 kB)\n", 163 | "Collecting mpmath>=0.19\n", 164 | " Using cached mpmath-1.3.0-py3-none-any.whl (536 kB)\n", 165 | "Installing collected packages: mpmath, typing-extensions, sympy, numpy, networkx, MarkupSafe, fsspec, filelock, pytorch-triton, jinja2, torch\n", 166 | " Attempting uninstall: mpmath\n", 167 | " Found existing installation: mpmath 1.3.0\n", 168 | " Uninstalling mpmath-1.3.0:\n", 169 | " Successfully uninstalled mpmath-1.3.0\n", 170 | " Attempting uninstall: typing-extensions\n", 171 | " Found existing installation: typing_extensions 4.5.0\n", 172 | " Uninstalling typing_extensions-4.5.0:\n", 173 | " Successfully uninstalled typing_extensions-4.5.0\n", 174 | " Attempting uninstall: sympy\n", 175 | " Found existing installation: sympy 1.12rc1\n", 176 | " Uninstalling sympy-1.12rc1:\n", 177 | " Successfully uninstalled sympy-1.12rc1\n", 178 | " Attempting uninstall: numpy\n", 179 | " Found existing installation: numpy 1.24.1\n", 180 
| " Uninstalling numpy-1.24.1:\n", 181 | " Successfully uninstalled numpy-1.24.1\n", 182 | " Attempting uninstall: networkx\n", 183 | " Found existing installation: networkx 3.1\n", 184 | " Uninstalling networkx-3.1:\n", 185 | " Successfully uninstalled networkx-3.1\n", 186 | " Attempting uninstall: MarkupSafe\n", 187 | " Found existing installation: MarkupSafe 2.1.2\n", 188 | " Uninstalling MarkupSafe-2.1.2:\n", 189 | " Successfully uninstalled MarkupSafe-2.1.2\n", 190 | " Attempting uninstall: fsspec\n", 191 | " Found existing installation: fsspec 2023.4.0\n", 192 | " Uninstalling fsspec-2023.4.0:\n", 193 | " Successfully uninstalled fsspec-2023.4.0\n", 194 | " Attempting uninstall: filelock\n", 195 | " Found existing installation: filelock 3.12.0\n", 196 | " Uninstalling filelock-3.12.0:\n", 197 | " Successfully uninstalled filelock-3.12.0\n", 198 | " Attempting uninstall: pytorch-triton\n", 199 | " Found existing installation: pytorch-triton 2.1.0+7d1a95b046\n", 200 | " Uninstalling pytorch-triton-2.1.0+7d1a95b046:\n", 201 | " Successfully uninstalled pytorch-triton-2.1.0+7d1a95b046\n", 202 | " Attempting uninstall: jinja2\n", 203 | " Found existing installation: Jinja2 3.1.2\n", 204 | " Uninstalling Jinja2-3.1.2:\n", 205 | " Successfully uninstalled Jinja2-3.1.2\n", 206 | " Attempting uninstall: torch\n", 207 | " Found existing installation: torch 2.0.0\n", 208 | " Uninstalling torch-2.0.0:\n", 209 | " Successfully uninstalled torch-2.0.0\n", 210 | "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", 211 | "torchvision 0.15.1+cu118 requires torch==2.0.0, but you have torch 2.1.0.dev20230501+cu118 which is incompatible.\n", 212 | "torchtext 0.15.1 requires torch==2.0.0, but you have torch 2.1.0.dev20230501+cu118 which is incompatible.\n", 213 | "torchdata 0.6.0 requires torch==2.0.0, but you have torch 2.1.0.dev20230501+cu118 which is incompatible.\n", 214 | "torchaudio 2.0.1+cu118 requires torch==2.0.0, but you have torch 2.1.0.dev20230501+cu118 which is incompatible.\n", 215 | "tensorflow 2.12.0 requires numpy<1.24,>=1.22, but you have numpy 1.24.3 which is incompatible.\n", 216 | "numba 0.56.4 requires numpy<1.24,>=1.18, but you have numpy 1.24.3 which is incompatible.\n", 217 | "fastai 2.7.12 requires torch<2.1,>=1.7, but you have torch 2.1.0.dev20230501+cu118 which is incompatible.\u001b[0m\u001b[31m\n", 218 | "\u001b[0mSuccessfully installed MarkupSafe-2.1.2 filelock-3.12.0 fsspec-2023.4.0 jinja2-3.1.2 mpmath-1.3.0 networkx-3.1 numpy-1.24.3 pytorch-triton-2.1.0+7d1a95b046 sympy-1.12rc1 torch-2.1.0.dev20230501+cu118 typing-extensions-4.5.0\n" 219 | ] 220 | } 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 6, 226 | "metadata": { 227 | "id": "YQ7tOW0mSHW_" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "import torch\n", 232 | "import torch.nn as nn\n", 233 | "from torch.nn import functional as F\n", 234 | "import sys\n", 235 | "sys.path.insert(0, './formal-algorithms-for-transformers')\n", 236 | "from src.alg_5 import MHAttentionInefficient\n", 237 | "from src.alg_10 import DTransformer" 238 | ] 239 | }, 240 | { 241 | "cell_type": "code", 242 | "source": [ 243 | "torch.__version__" 244 | ], 245 | "metadata": { 246 | "colab": { 247 | "base_uri": "https://localhost:8080/", 248 | "height": 35 249 | }, 250 | "id": "_VlLVZkQJcXA", 251 | "outputId": "fc8d50a7-3305-416c-9fc4-b5528a7b74d2" 252 | }, 253 | "execution_count": 7, 254 | "outputs": [ 255 | { 256 | 
"output_type": "execute_result", 257 | "data": { 258 | "text/plain": [ 259 | "'2.1.0.dev20230501+cu118'" 260 | ], 261 | "application/vnd.google.colaboratory.intrinsic+json": { 262 | "type": "string" 263 | } 264 | }, 265 | "metadata": {}, 266 | "execution_count": 7 267 | } 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 8, 273 | "metadata": { 274 | "id": "EQOj-gzFUnfQ" 275 | }, 276 | "outputs": [], 277 | "source": [ 278 | "import pytorch_lightning as pl\n" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 9, 284 | "metadata": { 285 | "id": "33PF6V8EWJ92" 286 | }, 287 | "outputs": [], 288 | "source": [ 289 | "from typing import Any" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 10, 295 | "metadata": { 296 | "id": "9lXr78h0UIrV" 297 | }, 298 | "outputs": [], 299 | "source": [ 300 | "class GPT(pl.LightningModule):\n", 301 | "\n", 302 | " def __init__(self, embed_dim, mlp_dim, max_seq_len, L_dec, vocab_size, num_heads,\n", 303 | " grad_glip: float = 0.9, optimizer: str = \"AdamW\", learning_rate: float = 1e-3, \n", 304 | " lr_scheduler: str = None,**kwargs: Any\n", 305 | " ):\n", 306 | " super().__init__()\n", 307 | " self.save_hyperparameters() \n", 308 | " self.model = DTransformer(embed_dim=embed_dim, \n", 309 | " mlp_dim=mlp_dim, \n", 310 | " max_seq_len=max_seq_len,\n", 311 | " L_dec=L_dec, \n", 312 | " vocab_size=vocab_size, \n", 313 | " num_heads=num_heads\n", 314 | " )\n", 315 | "\n", 316 | " self.criterion = nn.NLLLoss()\n", 317 | " \n", 318 | "\n", 319 | " def forward(self, x):\n", 320 | " output = self.model(x)\n", 321 | " return output\n", 322 | "\n", 323 | " def training_step(self, batch, batch_idx):\n", 324 | " x, y = batch\n", 325 | " y_hat = self(x)\n", 326 | " loss = self.criterion(torch.log(y_hat[:,-1,:]), y)\n", 327 | " return loss\n", 328 | "\n", 329 | " def configure_optimizers(self):\n", 330 | " # ref: https://github.com/Lightning-AI/lightning/issues/7576\n", 331 
| " optimizer = getattr(torch.optim, self.hparams.optimizer)(\n", 332 | " self.parameters(),\n", 333 | " lr=self.hparams.learning_rate,\n", 334 | " )\n", 335 | " if self.hparams.lr_scheduler is None:\n", 336 | " return optimizer\n", 337 | " scheduler = self.configure_scheduler(optimizer, self.hparams.lr_scheduler)\n", 338 | " return [optimizer], [scheduler] \n", 339 | "\n", 340 | "\n" 341 | ] 342 | }, 343 | { 344 | "cell_type": "code", 345 | "execution_count": 11, 346 | "metadata": { 347 | "id": "CxmqvVeVNb6m" 348 | }, 349 | "outputs": [], 350 | "source": [ 351 | "from torch.utils.data import Dataset\n", 352 | "from torch.utils.data import DataLoader\n", 353 | "class Digits(Dataset):\n", 354 | " def __init__(self, vocab_size, N, max_seq_len):\n", 355 | " super().__init__()\n", 356 | " self.x_ids = torch.randint(0,vocab_size, size = (N, max_seq_len)) \n", 357 | " self.y_ids = self.x_ids.sum(1)%vocab_size\n", 358 | " def __len__(self):\n", 359 | " return len(self.x_ids)\n", 360 | "\n", 361 | " def __getitem__(self, idx):\n", 362 | " return self.x_ids[idx],self.y_ids[idx]" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 12, 368 | "metadata": { 369 | "id": "r1NdK6rZT7LB" 370 | }, 371 | "outputs": [], 372 | "source": [ 373 | "\n", 374 | "tr_loader = DataLoader(Digits(max_seq_len=2, vocab_size=12, N=32), \n", 375 | " batch_size=10000, shuffle=False) \n", 376 | "vl_loader = DataLoader(Digits(max_seq_len=2, vocab_size=12, N=1024), \n", 377 | " batch_size=10000, shuffle=True) " 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 44, 383 | "metadata": { 384 | "id": "8pJ3j28CUcSW" 385 | }, 386 | "outputs": [], 387 | "source": [] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 13, 392 | "metadata": { 393 | "id": "2o62LuwxUcJa", 394 | "colab": { 395 | "base_uri": "https://localhost:8080/" 396 | }, 397 | "outputId": "9a8083b9-f1ea-4b75-b0f8-d806c2ddf4fc" 398 | }, 399 | "outputs": [ 400 | { 401 | 
"output_type": "stream", 402 | "name": "stderr", 403 | "text": [ 404 | "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n", 405 | "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", 406 | "INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs\n", 407 | "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "trainer = pl.Trainer(max_epochs=1000)\n", 413 | "modulo_adder = GPT(embed_dim=50, mlp_dim=64, max_seq_len=2, L_dec=5, vocab_size=12, \n", 414 | " num_heads=5)" 415 | ] 416 | }, 417 | { 418 | "cell_type": "code", 419 | "execution_count": 14, 420 | "metadata": { 421 | "colab": { 422 | "base_uri": "https://localhost:8080/", 423 | "height": 408, 424 | "referenced_widgets": [ 425 | "e6ffa4033921427c8f9e7db4a1299e7d", 426 | "89b38d09df7b4bc3b49ed21e6052b9e6", 427 | "d49975885373480ca3833d58bc95a3f9", 428 | "9ab24683ddb345e6abb1b8b7b75076ea", 429 | "d90e979343e94ba9a7907b3400de4aa1", 430 | "5da4e592e09f42cf8f2039b10b527285", 431 | "946b3920189d43b5955184572bd23cbf", 432 | "e147f1498b9d462dadcc82392cec7d30", 433 | "9f50cdd69138403181e4a188bcc1fe8d", 434 | "aa6cd9138fed4f47a86d23e1c6879740", 435 | "293f31185a6c4fe8a68a577b2c39fdeb" 436 | ] 437 | }, 438 | "id": "SQNJ6397Vxv7", 439 | "outputId": "76080597-7ac7-4844-d63e-299f36e88a59" 440 | }, 441 | "outputs": [ 442 | { 443 | "output_type": "stream", 444 | "name": "stderr", 445 | "text": [ 446 | "INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 447 | "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 448 | "INFO:pytorch_lightning.callbacks.model_summary:\n", 449 | " | Name | Type | Params\n", 450 | "-------------------------------------------\n", 451 | "0 | model | DTransformer | 288 K \n", 452 | "1 | criterion | NLLLoss | 0 \n", 453 | "-------------------------------------------\n", 454 | "288 K Trainable params\n", 455 | "0 Non-trainable params\n", 456 | "288 K Total params\n", 457 | "1.156 Total estimated model params size (MB)\n", 458 | "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 459 | " rank_zero_warn(\n", 460 | "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/fit_loop.py:280: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). 
Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", 461 | " rank_zero_warn(\n" 462 | ] 463 | }, 464 | { 465 | "output_type": "display_data", 466 | "data": { 467 | "text/plain": [ 468 | "Training: 0it [00:00, ?it/s]" 469 | ], 470 | "application/vnd.jupyter.widget-view+json": { 471 | "version_major": 2, 472 | "version_minor": 0, 473 | "model_id": "e6ffa4033921427c8f9e7db4a1299e7d" 474 | } 475 | }, 476 | "metadata": {} 477 | }, 478 | { 479 | "output_type": "stream", 480 | "name": "stderr", 481 | "text": [ 482 | "INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1000` reached.\n" 483 | ] 484 | }, 485 | { 486 | "output_type": "stream", 487 | "name": "stdout", 488 | "text": [ 489 | "CPU times: user 1min 43s, sys: 6.32 s, total: 1min 49s\n", 490 | "Wall time: 1min 46s\n" 491 | ] 492 | } 493 | ], 494 | "source": [ 495 | "%%time\n", 496 | "trainer.fit(model=modulo_adder, train_dataloaders=tr_loader, \n", 497 | " #val_dataloaders=vl_loader\n", 498 | " )" 499 | ] 500 | }, 501 | { 502 | "cell_type": "code", 503 | "source": [ 504 | "for batch in tr_loader:\n", 505 | " break\n", 506 | "x, y = batch \n", 507 | "y_hat = modulo_adder(x) \n", 508 | "torch.mean((y_hat[:,-1,:].argmax(1) == y).float())" 509 | ], 510 | "metadata": { 511 | "colab": { 512 | "base_uri": "https://localhost:8080/" 513 | }, 514 | "id": "j--Z9lqdH3ry", 515 | "outputId": "3d95e849-6509-4916-ef17-3bf4a6caf337" 516 | }, 517 | "execution_count": 15, 518 | "outputs": [ 519 | { 520 | "output_type": "execute_result", 521 | "data": { 522 | "text/plain": [ 523 | "tensor(0.9375)" 524 | ] 525 | }, 526 | "metadata": {}, 527 | "execution_count": 15 528 | } 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "source": [ 534 | "testing to see if torch.compile gives any speedup" 535 | ], 536 | "metadata": { 537 | "id": "WLwPkMj1NFqj" 538 | } 539 | }, 540 | { 541 | "cell_type": "code", 542 | "source": [ 543 | "trainer = 
pl.Trainer(max_epochs=1000)\n", 544 | "modulo_adder = GPT(embed_dim=50, mlp_dim=64, max_seq_len=2, L_dec=5, vocab_size=12, \n", 545 | " num_heads=5)\n", 546 | "modulo_adder = torch.compile(modulo_adder)" 547 | ], 548 | "metadata": { 549 | "colab": { 550 | "base_uri": "https://localhost:8080/" 551 | }, 552 | "id": "7lZ6-GFTGpO6", 553 | "outputId": "4489fedf-6330-4639-987e-3d338db145ee" 554 | }, 555 | "execution_count": 17, 556 | "outputs": [ 557 | { 558 | "output_type": "stream", 559 | "name": "stderr", 560 | "text": [ 561 | "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n", 562 | "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", 563 | "INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs\n", 564 | "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" 565 | ] 566 | } 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "source": [ 572 | "%%time\n", 573 | "trainer.fit(model=modulo_adder, train_dataloaders=tr_loader, \n", 574 | " #val_dataloaders=vl_loader\n", 575 | " )" 576 | ], 577 | "metadata": { 578 | "colab": { 579 | "base_uri": "https://localhost:8080/", 580 | "height": 373, 581 | "referenced_widgets": [ 582 | "c71c99fce4484f258c79d866ec4c9d0e", 583 | "b4cac92462074a1395355da3e211a4bc", 584 | "beb1a70370214e219ce35561adb8f3a7", 585 | "90b19e6e0d184bc5a40abb337007b767", 586 | "4a00b44c331146f48ea8b032ee5f5d9e", 587 | "faabf22fe42841c0b9769f901ff571ad", 588 | "86d5601def684b9691b8e4167f39b252", 589 | "02d1c82b972e4614a38df08036617218", 590 | "5af51a8b676e46efb3c0d6dbe271a42a", 591 | "d879a22ef90d4545a5f65922e92ce50d", 592 | "106fb922b3d94535974c3aac26436a06" 593 | ] 594 | }, 595 | "id": "pa7e2bBoGpE8", 596 | "outputId": "84ca63b7-7199-4151-db3d-a33475527f46" 597 | }, 598 | "execution_count": 18, 599 | "outputs": [ 600 | { 601 | "output_type": "stream", 602 | "name": "stderr", 603 | "text": [ 604 | 
"INFO:pytorch_lightning.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", 605 | "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", 606 | "INFO:pytorch_lightning.callbacks.model_summary:\n", 607 | " | Name | Type | Params\n", 608 | "-------------------------------------------\n", 609 | "0 | model | DTransformer | 288 K \n", 610 | "1 | criterion | NLLLoss | 0 \n", 611 | "-------------------------------------------\n", 612 | "288 K Trainable params\n", 613 | "0 Non-trainable params\n", 614 | "288 K Total params\n", 615 | "1.156 Total estimated model params size (MB)\n" 616 | ] 617 | }, 618 | { 619 | "output_type": "display_data", 620 | "data": { 621 | "text/plain": [ 622 | "Training: 0it [00:00, ?it/s]" 623 | ], 624 | "application/vnd.jupyter.widget-view+json": { 625 | "version_major": 2, 626 | "version_minor": 0, 627 | "model_id": "c71c99fce4484f258c79d866ec4c9d0e" 628 | } 629 | }, 630 | "metadata": {} 631 | }, 632 | { 633 | "output_type": "stream", 634 | "name": "stderr", 635 | "text": [ 636 | "/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:108: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. 
Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n", 637 | " warnings.warn(\n", 638 | "INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1000` reached.\n" 639 | ] 640 | }, 641 | { 642 | "output_type": "stream", 643 | "name": "stdout", 644 | "text": [ 645 | "CPU times: user 1min 59s, sys: 10.1 s, total: 2min 9s\n", 646 | "Wall time: 2min 9s\n" 647 | ] 648 | } 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "source": [ 654 | "for batch in tr_loader:\n", 655 | " break\n", 656 | "x, y = batch \n", 657 | "y_hat = modulo_adder(x) \n", 658 | "torch.mean((y_hat[:,-1,:].argmax(1) == y).float())" 659 | ], 660 | "metadata": { 661 | "id": "CY6OUmhqlrd2", 662 | "colab": { 663 | "base_uri": "https://localhost:8080/" 664 | }, 665 | "outputId": "0bbff58e-6dd5-43ca-b70c-f2c1518d3a38" 666 | }, 667 | "execution_count": 19, 668 | "outputs": [ 669 | { 670 | "output_type": "execute_result", 671 | "data": { 672 | "text/plain": [ 673 | "tensor(0.9375)" 674 | ] 675 | }, 676 | "metadata": {}, 677 | "execution_count": 19 678 | } 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "source": [], 684 | "metadata": { 685 | "id": "tBUcg058lzG8" 686 | }, 687 | "execution_count": 27, 688 | "outputs": [] 689 | }, 690 | { 691 | "cell_type": "code", 692 | "source": [], 693 | "metadata": { 694 | "id": "AjjWJMyYl09V" 695 | }, 696 | "execution_count": 28, 697 | "outputs": [] 698 | }, 699 | { 700 | "cell_type": "code", 701 | "execution_count": null, 702 | "metadata": { 703 | "id": "gt9plbPDTVFS" 704 | }, 705 | "outputs": [], 706 | "source": [] 707 | } 708 | ], 709 | "metadata": { 710 | "accelerator": "GPU", 711 | "colab": { 712 | "provenance": [], 713 | "machine_shape": "hm", 714 | "gpuType": "A100", 715 | "authorship_tag": "ABX9TyMpjJVJKiWq1+BXtxhsGJMp", 716 | "include_colab_link": true 717 | }, 718 | "gpuClass": "standard", 719 | "kernelspec": { 720 | "display_name": "Python 3", 721 | "name": "python3" 722 | }, 723 | 
"language_info": { 724 | "name": "python" 725 | }, 726 | "widgets": { 727 | "application/vnd.jupyter.widget-state+json": { 728 | "e6ffa4033921427c8f9e7db4a1299e7d": { 729 | "model_module": "@jupyter-widgets/controls", 730 | "model_name": "HBoxModel", 731 | "model_module_version": "1.5.0", 732 | "state": { 733 | "_dom_classes": [], 734 | "_model_module": "@jupyter-widgets/controls", 735 | "_model_module_version": "1.5.0", 736 | "_model_name": "HBoxModel", 737 | "_view_count": null, 738 | "_view_module": "@jupyter-widgets/controls", 739 | "_view_module_version": "1.5.0", 740 | "_view_name": "HBoxView", 741 | "box_style": "", 742 | "children": [ 743 | "IPY_MODEL_89b38d09df7b4bc3b49ed21e6052b9e6", 744 | "IPY_MODEL_d49975885373480ca3833d58bc95a3f9", 745 | "IPY_MODEL_9ab24683ddb345e6abb1b8b7b75076ea" 746 | ], 747 | "layout": "IPY_MODEL_d90e979343e94ba9a7907b3400de4aa1" 748 | } 749 | }, 750 | "89b38d09df7b4bc3b49ed21e6052b9e6": { 751 | "model_module": "@jupyter-widgets/controls", 752 | "model_name": "HTMLModel", 753 | "model_module_version": "1.5.0", 754 | "state": { 755 | "_dom_classes": [], 756 | "_model_module": "@jupyter-widgets/controls", 757 | "_model_module_version": "1.5.0", 758 | "_model_name": "HTMLModel", 759 | "_view_count": null, 760 | "_view_module": "@jupyter-widgets/controls", 761 | "_view_module_version": "1.5.0", 762 | "_view_name": "HTMLView", 763 | "description": "", 764 | "description_tooltip": null, 765 | "layout": "IPY_MODEL_5da4e592e09f42cf8f2039b10b527285", 766 | "placeholder": "​", 767 | "style": "IPY_MODEL_946b3920189d43b5955184572bd23cbf", 768 | "value": "Epoch 999: 100%" 769 | } 770 | }, 771 | "d49975885373480ca3833d58bc95a3f9": { 772 | "model_module": "@jupyter-widgets/controls", 773 | "model_name": "FloatProgressModel", 774 | "model_module_version": "1.5.0", 775 | "state": { 776 | "_dom_classes": [], 777 | "_model_module": "@jupyter-widgets/controls", 778 | "_model_module_version": "1.5.0", 779 | "_model_name": "FloatProgressModel", 780 | 
"_view_count": null, 781 | "_view_module": "@jupyter-widgets/controls", 782 | "_view_module_version": "1.5.0", 783 | "_view_name": "ProgressView", 784 | "bar_style": "success", 785 | "description": "", 786 | "description_tooltip": null, 787 | "layout": "IPY_MODEL_e147f1498b9d462dadcc82392cec7d30", 788 | "max": 1, 789 | "min": 0, 790 | "orientation": "horizontal", 791 | "style": "IPY_MODEL_9f50cdd69138403181e4a188bcc1fe8d", 792 | "value": 1 793 | } 794 | }, 795 | "9ab24683ddb345e6abb1b8b7b75076ea": { 796 | "model_module": "@jupyter-widgets/controls", 797 | "model_name": "HTMLModel", 798 | "model_module_version": "1.5.0", 799 | "state": { 800 | "_dom_classes": [], 801 | "_model_module": "@jupyter-widgets/controls", 802 | "_model_module_version": "1.5.0", 803 | "_model_name": "HTMLModel", 804 | "_view_count": null, 805 | "_view_module": "@jupyter-widgets/controls", 806 | "_view_module_version": "1.5.0", 807 | "_view_name": "HTMLView", 808 | "description": "", 809 | "description_tooltip": null, 810 | "layout": "IPY_MODEL_aa6cd9138fed4f47a86d23e1c6879740", 811 | "placeholder": "​", 812 | "style": "IPY_MODEL_293f31185a6c4fe8a68a577b2c39fdeb", 813 | "value": " 1/1 [00:00<00:00, 9.03it/s, v_num=9]" 814 | } 815 | }, 816 | "d90e979343e94ba9a7907b3400de4aa1": { 817 | "model_module": "@jupyter-widgets/base", 818 | "model_name": "LayoutModel", 819 | "model_module_version": "1.2.0", 820 | "state": { 821 | "_model_module": "@jupyter-widgets/base", 822 | "_model_module_version": "1.2.0", 823 | "_model_name": "LayoutModel", 824 | "_view_count": null, 825 | "_view_module": "@jupyter-widgets/base", 826 | "_view_module_version": "1.2.0", 827 | "_view_name": "LayoutView", 828 | "align_content": null, 829 | "align_items": null, 830 | "align_self": null, 831 | "border": null, 832 | "bottom": null, 833 | "display": "inline-flex", 834 | "flex": null, 835 | "flex_flow": "row wrap", 836 | "grid_area": null, 837 | "grid_auto_columns": null, 838 | "grid_auto_flow": null, 839 | 
"grid_auto_rows": null, 840 | "grid_column": null, 841 | "grid_gap": null, 842 | "grid_row": null, 843 | "grid_template_areas": null, 844 | "grid_template_columns": null, 845 | "grid_template_rows": null, 846 | "height": null, 847 | "justify_content": null, 848 | "justify_items": null, 849 | "left": null, 850 | "margin": null, 851 | "max_height": null, 852 | "max_width": null, 853 | "min_height": null, 854 | "min_width": null, 855 | "object_fit": null, 856 | "object_position": null, 857 | "order": null, 858 | "overflow": null, 859 | "overflow_x": null, 860 | "overflow_y": null, 861 | "padding": null, 862 | "right": null, 863 | "top": null, 864 | "visibility": null, 865 | "width": "100%" 866 | } 867 | }, 868 | "5da4e592e09f42cf8f2039b10b527285": { 869 | "model_module": "@jupyter-widgets/base", 870 | "model_name": "LayoutModel", 871 | "model_module_version": "1.2.0", 872 | "state": { 873 | "_model_module": "@jupyter-widgets/base", 874 | "_model_module_version": "1.2.0", 875 | "_model_name": "LayoutModel", 876 | "_view_count": null, 877 | "_view_module": "@jupyter-widgets/base", 878 | "_view_module_version": "1.2.0", 879 | "_view_name": "LayoutView", 880 | "align_content": null, 881 | "align_items": null, 882 | "align_self": null, 883 | "border": null, 884 | "bottom": null, 885 | "display": null, 886 | "flex": null, 887 | "flex_flow": null, 888 | "grid_area": null, 889 | "grid_auto_columns": null, 890 | "grid_auto_flow": null, 891 | "grid_auto_rows": null, 892 | "grid_column": null, 893 | "grid_gap": null, 894 | "grid_row": null, 895 | "grid_template_areas": null, 896 | "grid_template_columns": null, 897 | "grid_template_rows": null, 898 | "height": null, 899 | "justify_content": null, 900 | "justify_items": null, 901 | "left": null, 902 | "margin": null, 903 | "max_height": null, 904 | "max_width": null, 905 | "min_height": null, 906 | "min_width": null, 907 | "object_fit": null, 908 | "object_position": null, 909 | "order": null, 910 | "overflow": null, 911 | 
"overflow_x": null, 912 | "overflow_y": null, 913 | "padding": null, 914 | "right": null, 915 | "top": null, 916 | "visibility": null, 917 | "width": null 918 | } 919 | }, 920 | "946b3920189d43b5955184572bd23cbf": { 921 | "model_module": "@jupyter-widgets/controls", 922 | "model_name": "DescriptionStyleModel", 923 | "model_module_version": "1.5.0", 924 | "state": { 925 | "_model_module": "@jupyter-widgets/controls", 926 | "_model_module_version": "1.5.0", 927 | "_model_name": "DescriptionStyleModel", 928 | "_view_count": null, 929 | "_view_module": "@jupyter-widgets/base", 930 | "_view_module_version": "1.2.0", 931 | "_view_name": "StyleView", 932 | "description_width": "" 933 | } 934 | }, 935 | "e147f1498b9d462dadcc82392cec7d30": { 936 | "model_module": "@jupyter-widgets/base", 937 | "model_name": "LayoutModel", 938 | "model_module_version": "1.2.0", 939 | "state": { 940 | "_model_module": "@jupyter-widgets/base", 941 | "_model_module_version": "1.2.0", 942 | "_model_name": "LayoutModel", 943 | "_view_count": null, 944 | "_view_module": "@jupyter-widgets/base", 945 | "_view_module_version": "1.2.0", 946 | "_view_name": "LayoutView", 947 | "align_content": null, 948 | "align_items": null, 949 | "align_self": null, 950 | "border": null, 951 | "bottom": null, 952 | "display": null, 953 | "flex": "2", 954 | "flex_flow": null, 955 | "grid_area": null, 956 | "grid_auto_columns": null, 957 | "grid_auto_flow": null, 958 | "grid_auto_rows": null, 959 | "grid_column": null, 960 | "grid_gap": null, 961 | "grid_row": null, 962 | "grid_template_areas": null, 963 | "grid_template_columns": null, 964 | "grid_template_rows": null, 965 | "height": null, 966 | "justify_content": null, 967 | "justify_items": null, 968 | "left": null, 969 | "margin": null, 970 | "max_height": null, 971 | "max_width": null, 972 | "min_height": null, 973 | "min_width": null, 974 | "object_fit": null, 975 | "object_position": null, 976 | "order": null, 977 | "overflow": null, 978 | "overflow_x": null, 
979 | "overflow_y": null, 980 | "padding": null, 981 | "right": null, 982 | "top": null, 983 | "visibility": null, 984 | "width": null 985 | } 986 | }, 987 | "9f50cdd69138403181e4a188bcc1fe8d": { 988 | "model_module": "@jupyter-widgets/controls", 989 | "model_name": "ProgressStyleModel", 990 | "model_module_version": "1.5.0", 991 | "state": { 992 | "_model_module": "@jupyter-widgets/controls", 993 | "_model_module_version": "1.5.0", 994 | "_model_name": "ProgressStyleModel", 995 | "_view_count": null, 996 | "_view_module": "@jupyter-widgets/base", 997 | "_view_module_version": "1.2.0", 998 | "_view_name": "StyleView", 999 | "bar_color": null, 1000 | "description_width": "" 1001 | } 1002 | }, 1003 | "aa6cd9138fed4f47a86d23e1c6879740": { 1004 | "model_module": "@jupyter-widgets/base", 1005 | "model_name": "LayoutModel", 1006 | "model_module_version": "1.2.0", 1007 | "state": { 1008 | "_model_module": "@jupyter-widgets/base", 1009 | "_model_module_version": "1.2.0", 1010 | "_model_name": "LayoutModel", 1011 | "_view_count": null, 1012 | "_view_module": "@jupyter-widgets/base", 1013 | "_view_module_version": "1.2.0", 1014 | "_view_name": "LayoutView", 1015 | "align_content": null, 1016 | "align_items": null, 1017 | "align_self": null, 1018 | "border": null, 1019 | "bottom": null, 1020 | "display": null, 1021 | "flex": null, 1022 | "flex_flow": null, 1023 | "grid_area": null, 1024 | "grid_auto_columns": null, 1025 | "grid_auto_flow": null, 1026 | "grid_auto_rows": null, 1027 | "grid_column": null, 1028 | "grid_gap": null, 1029 | "grid_row": null, 1030 | "grid_template_areas": null, 1031 | "grid_template_columns": null, 1032 | "grid_template_rows": null, 1033 | "height": null, 1034 | "justify_content": null, 1035 | "justify_items": null, 1036 | "left": null, 1037 | "margin": null, 1038 | "max_height": null, 1039 | "max_width": null, 1040 | "min_height": null, 1041 | "min_width": null, 1042 | "object_fit": null, 1043 | "object_position": null, 1044 | "order": null, 1045 | 
"overflow": null, 1046 | "overflow_x": null, 1047 | "overflow_y": null, 1048 | "padding": null, 1049 | "right": null, 1050 | "top": null, 1051 | "visibility": null, 1052 | "width": null 1053 | } 1054 | }, 1055 | "293f31185a6c4fe8a68a577b2c39fdeb": { 1056 | "model_module": "@jupyter-widgets/controls", 1057 | "model_name": "DescriptionStyleModel", 1058 | "model_module_version": "1.5.0", 1059 | "state": { 1060 | "_model_module": "@jupyter-widgets/controls", 1061 | "_model_module_version": "1.5.0", 1062 | "_model_name": "DescriptionStyleModel", 1063 | "_view_count": null, 1064 | "_view_module": "@jupyter-widgets/base", 1065 | "_view_module_version": "1.2.0", 1066 | "_view_name": "StyleView", 1067 | "description_width": "" 1068 | } 1069 | }, 1070 | "c71c99fce4484f258c79d866ec4c9d0e": { 1071 | "model_module": "@jupyter-widgets/controls", 1072 | "model_name": "HBoxModel", 1073 | "model_module_version": "1.5.0", 1074 | "state": { 1075 | "_dom_classes": [], 1076 | "_model_module": "@jupyter-widgets/controls", 1077 | "_model_module_version": "1.5.0", 1078 | "_model_name": "HBoxModel", 1079 | "_view_count": null, 1080 | "_view_module": "@jupyter-widgets/controls", 1081 | "_view_module_version": "1.5.0", 1082 | "_view_name": "HBoxView", 1083 | "box_style": "", 1084 | "children": [ 1085 | "IPY_MODEL_b4cac92462074a1395355da3e211a4bc", 1086 | "IPY_MODEL_beb1a70370214e219ce35561adb8f3a7", 1087 | "IPY_MODEL_90b19e6e0d184bc5a40abb337007b767" 1088 | ], 1089 | "layout": "IPY_MODEL_4a00b44c331146f48ea8b032ee5f5d9e" 1090 | } 1091 | }, 1092 | "b4cac92462074a1395355da3e211a4bc": { 1093 | "model_module": "@jupyter-widgets/controls", 1094 | "model_name": "HTMLModel", 1095 | "model_module_version": "1.5.0", 1096 | "state": { 1097 | "_dom_classes": [], 1098 | "_model_module": "@jupyter-widgets/controls", 1099 | "_model_module_version": "1.5.0", 1100 | "_model_name": "HTMLModel", 1101 | "_view_count": null, 1102 | "_view_module": "@jupyter-widgets/controls", 1103 | "_view_module_version": 
"1.5.0", 1104 | "_view_name": "HTMLView", 1105 | "description": "", 1106 | "description_tooltip": null, 1107 | "layout": "IPY_MODEL_faabf22fe42841c0b9769f901ff571ad", 1108 | "placeholder": "​", 1109 | "style": "IPY_MODEL_86d5601def684b9691b8e4167f39b252", 1110 | "value": "Epoch 999: 100%" 1111 | } 1112 | }, 1113 | "beb1a70370214e219ce35561adb8f3a7": { 1114 | "model_module": "@jupyter-widgets/controls", 1115 | "model_name": "FloatProgressModel", 1116 | "model_module_version": "1.5.0", 1117 | "state": { 1118 | "_dom_classes": [], 1119 | "_model_module": "@jupyter-widgets/controls", 1120 | "_model_module_version": "1.5.0", 1121 | "_model_name": "FloatProgressModel", 1122 | "_view_count": null, 1123 | "_view_module": "@jupyter-widgets/controls", 1124 | "_view_module_version": "1.5.0", 1125 | "_view_name": "ProgressView", 1126 | "bar_style": "success", 1127 | "description": "", 1128 | "description_tooltip": null, 1129 | "layout": "IPY_MODEL_02d1c82b972e4614a38df08036617218", 1130 | "max": 1, 1131 | "min": 0, 1132 | "orientation": "horizontal", 1133 | "style": "IPY_MODEL_5af51a8b676e46efb3c0d6dbe271a42a", 1134 | "value": 1 1135 | } 1136 | }, 1137 | "90b19e6e0d184bc5a40abb337007b767": { 1138 | "model_module": "@jupyter-widgets/controls", 1139 | "model_name": "HTMLModel", 1140 | "model_module_version": "1.5.0", 1141 | "state": { 1142 | "_dom_classes": [], 1143 | "_model_module": "@jupyter-widgets/controls", 1144 | "_model_module_version": "1.5.0", 1145 | "_model_name": "HTMLModel", 1146 | "_view_count": null, 1147 | "_view_module": "@jupyter-widgets/controls", 1148 | "_view_module_version": "1.5.0", 1149 | "_view_name": "HTMLView", 1150 | "description": "", 1151 | "description_tooltip": null, 1152 | "layout": "IPY_MODEL_d879a22ef90d4545a5f65922e92ce50d", 1153 | "placeholder": "​", 1154 | "style": "IPY_MODEL_106fb922b3d94535974c3aac26436a06", 1155 | "value": " 1/1 [00:00<00:00, 10.90it/s, v_num=10]" 1156 | } 1157 | }, 1158 | "4a00b44c331146f48ea8b032ee5f5d9e": { 1159 | 
"model_module": "@jupyter-widgets/base", 1160 | "model_name": "LayoutModel", 1161 | "model_module_version": "1.2.0", 1162 | "state": { 1163 | "_model_module": "@jupyter-widgets/base", 1164 | "_model_module_version": "1.2.0", 1165 | "_model_name": "LayoutModel", 1166 | "_view_count": null, 1167 | "_view_module": "@jupyter-widgets/base", 1168 | "_view_module_version": "1.2.0", 1169 | "_view_name": "LayoutView", 1170 | "align_content": null, 1171 | "align_items": null, 1172 | "align_self": null, 1173 | "border": null, 1174 | "bottom": null, 1175 | "display": "inline-flex", 1176 | "flex": null, 1177 | "flex_flow": "row wrap", 1178 | "grid_area": null, 1179 | "grid_auto_columns": null, 1180 | "grid_auto_flow": null, 1181 | "grid_auto_rows": null, 1182 | "grid_column": null, 1183 | "grid_gap": null, 1184 | "grid_row": null, 1185 | "grid_template_areas": null, 1186 | "grid_template_columns": null, 1187 | "grid_template_rows": null, 1188 | "height": null, 1189 | "justify_content": null, 1190 | "justify_items": null, 1191 | "left": null, 1192 | "margin": null, 1193 | "max_height": null, 1194 | "max_width": null, 1195 | "min_height": null, 1196 | "min_width": null, 1197 | "object_fit": null, 1198 | "object_position": null, 1199 | "order": null, 1200 | "overflow": null, 1201 | "overflow_x": null, 1202 | "overflow_y": null, 1203 | "padding": null, 1204 | "right": null, 1205 | "top": null, 1206 | "visibility": null, 1207 | "width": "100%" 1208 | } 1209 | }, 1210 | "faabf22fe42841c0b9769f901ff571ad": { 1211 | "model_module": "@jupyter-widgets/base", 1212 | "model_name": "LayoutModel", 1213 | "model_module_version": "1.2.0", 1214 | "state": { 1215 | "_model_module": "@jupyter-widgets/base", 1216 | "_model_module_version": "1.2.0", 1217 | "_model_name": "LayoutModel", 1218 | "_view_count": null, 1219 | "_view_module": "@jupyter-widgets/base", 1220 | "_view_module_version": "1.2.0", 1221 | "_view_name": "LayoutView", 1222 | "align_content": null, 1223 | "align_items": null, 1224 | 
"align_self": null, 1225 | "border": null, 1226 | "bottom": null, 1227 | "display": null, 1228 | "flex": null, 1229 | "flex_flow": null, 1230 | "grid_area": null, 1231 | "grid_auto_columns": null, 1232 | "grid_auto_flow": null, 1233 | "grid_auto_rows": null, 1234 | "grid_column": null, 1235 | "grid_gap": null, 1236 | "grid_row": null, 1237 | "grid_template_areas": null, 1238 | "grid_template_columns": null, 1239 | "grid_template_rows": null, 1240 | "height": null, 1241 | "justify_content": null, 1242 | "justify_items": null, 1243 | "left": null, 1244 | "margin": null, 1245 | "max_height": null, 1246 | "max_width": null, 1247 | "min_height": null, 1248 | "min_width": null, 1249 | "object_fit": null, 1250 | "object_position": null, 1251 | "order": null, 1252 | "overflow": null, 1253 | "overflow_x": null, 1254 | "overflow_y": null, 1255 | "padding": null, 1256 | "right": null, 1257 | "top": null, 1258 | "visibility": null, 1259 | "width": null 1260 | } 1261 | }, 1262 | "86d5601def684b9691b8e4167f39b252": { 1263 | "model_module": "@jupyter-widgets/controls", 1264 | "model_name": "DescriptionStyleModel", 1265 | "model_module_version": "1.5.0", 1266 | "state": { 1267 | "_model_module": "@jupyter-widgets/controls", 1268 | "_model_module_version": "1.5.0", 1269 | "_model_name": "DescriptionStyleModel", 1270 | "_view_count": null, 1271 | "_view_module": "@jupyter-widgets/base", 1272 | "_view_module_version": "1.2.0", 1273 | "_view_name": "StyleView", 1274 | "description_width": "" 1275 | } 1276 | }, 1277 | "02d1c82b972e4614a38df08036617218": { 1278 | "model_module": "@jupyter-widgets/base", 1279 | "model_name": "LayoutModel", 1280 | "model_module_version": "1.2.0", 1281 | "state": { 1282 | "_model_module": "@jupyter-widgets/base", 1283 | "_model_module_version": "1.2.0", 1284 | "_model_name": "LayoutModel", 1285 | "_view_count": null, 1286 | "_view_module": "@jupyter-widgets/base", 1287 | "_view_module_version": "1.2.0", 1288 | "_view_name": "LayoutView", 1289 | 
"align_content": null, 1290 | "align_items": null, 1291 | "align_self": null, 1292 | "border": null, 1293 | "bottom": null, 1294 | "display": null, 1295 | "flex": "2", 1296 | "flex_flow": null, 1297 | "grid_area": null, 1298 | "grid_auto_columns": null, 1299 | "grid_auto_flow": null, 1300 | "grid_auto_rows": null, 1301 | "grid_column": null, 1302 | "grid_gap": null, 1303 | "grid_row": null, 1304 | "grid_template_areas": null, 1305 | "grid_template_columns": null, 1306 | "grid_template_rows": null, 1307 | "height": null, 1308 | "justify_content": null, 1309 | "justify_items": null, 1310 | "left": null, 1311 | "margin": null, 1312 | "max_height": null, 1313 | "max_width": null, 1314 | "min_height": null, 1315 | "min_width": null, 1316 | "object_fit": null, 1317 | "object_position": null, 1318 | "order": null, 1319 | "overflow": null, 1320 | "overflow_x": null, 1321 | "overflow_y": null, 1322 | "padding": null, 1323 | "right": null, 1324 | "top": null, 1325 | "visibility": null, 1326 | "width": null 1327 | } 1328 | }, 1329 | "5af51a8b676e46efb3c0d6dbe271a42a": { 1330 | "model_module": "@jupyter-widgets/controls", 1331 | "model_name": "ProgressStyleModel", 1332 | "model_module_version": "1.5.0", 1333 | "state": { 1334 | "_model_module": "@jupyter-widgets/controls", 1335 | "_model_module_version": "1.5.0", 1336 | "_model_name": "ProgressStyleModel", 1337 | "_view_count": null, 1338 | "_view_module": "@jupyter-widgets/base", 1339 | "_view_module_version": "1.2.0", 1340 | "_view_name": "StyleView", 1341 | "bar_color": null, 1342 | "description_width": "" 1343 | } 1344 | }, 1345 | "d879a22ef90d4545a5f65922e92ce50d": { 1346 | "model_module": "@jupyter-widgets/base", 1347 | "model_name": "LayoutModel", 1348 | "model_module_version": "1.2.0", 1349 | "state": { 1350 | "_model_module": "@jupyter-widgets/base", 1351 | "_model_module_version": "1.2.0", 1352 | "_model_name": "LayoutModel", 1353 | "_view_count": null, 1354 | "_view_module": "@jupyter-widgets/base", 1355 | 
"_view_module_version": "1.2.0", 1356 | "_view_name": "LayoutView", 1357 | "align_content": null, 1358 | "align_items": null, 1359 | "align_self": null, 1360 | "border": null, 1361 | "bottom": null, 1362 | "display": null, 1363 | "flex": null, 1364 | "flex_flow": null, 1365 | "grid_area": null, 1366 | "grid_auto_columns": null, 1367 | "grid_auto_flow": null, 1368 | "grid_auto_rows": null, 1369 | "grid_column": null, 1370 | "grid_gap": null, 1371 | "grid_row": null, 1372 | "grid_template_areas": null, 1373 | "grid_template_columns": null, 1374 | "grid_template_rows": null, 1375 | "height": null, 1376 | "justify_content": null, 1377 | "justify_items": null, 1378 | "left": null, 1379 | "margin": null, 1380 | "max_height": null, 1381 | "max_width": null, 1382 | "min_height": null, 1383 | "min_width": null, 1384 | "object_fit": null, 1385 | "object_position": null, 1386 | "order": null, 1387 | "overflow": null, 1388 | "overflow_x": null, 1389 | "overflow_y": null, 1390 | "padding": null, 1391 | "right": null, 1392 | "top": null, 1393 | "visibility": null, 1394 | "width": null 1395 | } 1396 | }, 1397 | "106fb922b3d94535974c3aac26436a06": { 1398 | "model_module": "@jupyter-widgets/controls", 1399 | "model_name": "DescriptionStyleModel", 1400 | "model_module_version": "1.5.0", 1401 | "state": { 1402 | "_model_module": "@jupyter-widgets/controls", 1403 | "_model_module_version": "1.5.0", 1404 | "_model_name": "DescriptionStyleModel", 1405 | "_view_count": null, 1406 | "_view_module": "@jupyter-widgets/base", 1407 | "_view_module_version": "1.2.0", 1408 | "_view_name": "StyleView", 1409 | "description_width": "" 1410 | } 1411 | } 1412 | } 1413 | } 1414 | }, 1415 | "nbformat": 4, 1416 | "nbformat_minor": 0 1417 | } -------------------------------------------------------------------------------- /src/alg_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 
class TokenEmbedding(nn.Module):
    """Algorithm 1: map integer token ids to learned embedding vectors.

    Thin wrapper around ``nn.Embedding`` that keeps the vocabulary size
    and embedding width around as plain attributes for introspection.
    """

    def __init__(self, vocab_size, embed_dim, **kwargs):
        super().__init__()
        self.vocab_size = vocab_size  # number of distinct tokens
        self.embed_dim = embed_dim    # width of each embedding vector
        self.emb = nn.Embedding(vocab_size, embed_dim, **kwargs)

    def forward(self, idx):
        """Look up embeddings for a tensor of token ids (any shape)."""
        return self.emb(idx)


if __name__ == "__main__":
    # Smoke test: embed a random batch of ids and print the result shape.
    vocab_size, embed_dim = 10000, 50
    token_emb = TokenEmbedding(vocab_size, embed_dim)
    batch_size, seq_len = 32, 128
    # Ids are drawn from [0, 100), well inside the 10000-entry table.
    idx = torch.randint(0, 100, size=(batch_size, seq_len))
    print(token_emb(idx).shape)
nn.ModuleList() 37 | for i in range(L_dec): 38 | self.decoder_layers.add_module(f"dec_layer_norm1_{i}", LayerNorm(embed_dim)) 39 | multi_head_attention = MHAttentionInefficient(num_heads=num_heads, 40 | input_dim=embed_dim, 41 | atten_dim=embed_dim, 42 | output_dim=embed_dim) 43 | self.decoder_layers.add_module(f"dec_attention_layer_{i}",DTransformer.ResNet(multi_head_attention)) 44 | self.decoder_layers.add_module(f"dec_layer_norm2_{i}", LayerNorm(embed_dim)) 45 | mlp = nn.Sequential( 46 | nn.Linear(embed_dim, mlp_dim), 47 | nn.GELU(), 48 | nn.Linear(mlp_dim, embed_dim) 49 | ) 50 | self.decoder_layers.add_module(f"dec_mlp_layer_{i}", DTransformer.ResNet(mlp)) 51 | 52 | self.layer_norm = LayerNorm(embed_dim) 53 | self.unembed = TokenUnembedding(vocab_size, embed_dim) 54 | self.register_buffer("mask", torch.tril(torch.ones(self.max_seq_len, self.max_seq_len)) 55 | .view(1, self.max_seq_len, self.max_seq_len)) 56 | 57 | 58 | def forward(self, x): 59 | lx = x.size()[1] #max seq len 60 | x = self.token_emb(x) + self.pos_emb(lx)[None,:,:] 61 | for name, layer in self.decoder_layers.named_children(): 62 | if "dec_attention_layer_" in name: 63 | x = layer(x,x, self.mask.masked_fill(self.mask==0, float('-inf'))) 64 | else: 65 | x = layer(x) 66 | x = self.layer_norm(x) 67 | return self.unembed(x) 68 | 69 | 70 | if __name__ == "__main__": 71 | max_seq_len = 512 72 | embed_dim = 50 73 | vocab_size = 10000 74 | ed_seq2seq = DTransformer(embed_dim=embed_dim, mlp_dim=32, max_seq_len=max_seq_len, 75 | L_dec=3, vocab_size=vocab_size, num_heads=3) 76 | 77 | 78 | bs = 32 79 | x_ids = torch.randint(0,vocab_size, size = (bs, max_seq_len)) 80 | print(ed_seq2seq(x_ids).size()) 81 | -------------------------------------------------------------------------------- /src/alg_2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | class PositionEmbedding(nn.Module): 6 | 
class SingleQueryAttention(nn.Module):
    """Algorithm 3: attention for a single query token over a context.

    Projects the current token to a query and the context tokens to
    keys/values, then returns the attention-weighted value vector.
    """

    def __init__(self, input_dim, atten_dim, output_dim, **kwargs):
        super().__init__()
        self.input_dim = input_dim    # width of incoming token vectors
        self.atten_dim = atten_dim    # query/key projection width
        self.output_dim = output_dim  # value/output width
        self.key = nn.Linear(input_dim, atten_dim)
        self.query = nn.Linear(input_dim, atten_dim)
        self.value = nn.Linear(input_dim, output_dim)

    def forward(self, current_token, context_tokens):
        """current_token: (bs, 1, input_dim); context_tokens: (bs, l, input_dim).

        Returns the attended value, shape (bs, 1, output_dim).
        """
        q = self.query(current_token)   # (bs, 1, atten_dim)
        k = self.key(context_tokens)    # (bs, l, atten_dim)
        v = self.value(context_tokens)  # (bs, l, output_dim)

        # BUG FIX: scores must be query-key dot products over the feature
        # axis, yielding one logit per context position. The previous
        # einsum ('ijk,ilk->ilk') kept the feature axis and then applied
        # softmax over features instead of over positions.
        att = torch.einsum('ijk,ilk->ijl', [q, k]) / np.sqrt(self.atten_dim)
        att = F.softmax(att, dim=-1)  # normalise over context positions
        # Weighted sum of values: (bs, 1, l) @ (bs, l, out) -> (bs, 1, out).
        return torch.einsum('ijl,ilk->ijk', [att, v])
class Attention(SingleQueryAttention):
    """Algorithm 4: batched attention of X over context Z, optional mask.

    ``attention_mask`` is additive: finite values for visible positions,
    -inf for blocked ones (as built by ``tril(...).masked_fill(...)``).
    """

    ## TODO : clean up attention mask specification API
    def __init__(self, input_dim, atten_dim, output_dim, **kwargs):
        super().__init__(input_dim, atten_dim, output_dim, **kwargs)
        # self.register_buffer("masked_bias", torch.tensor(-1e4))

    def forward(self, current_tokens, context_tokens, attention_mask=None):
        """current_tokens: (bs, lx, d); context_tokens: (bs, lz, d).

        Returns updated representations, shape (bs, lx, output_dim).
        """
        q = self.query(current_tokens)   # (bs, lx, atten_dim)
        k = self.key(context_tokens)     # (bs, lz, atten_dim)
        v = self.value(context_tokens)   # (bs, lz, output_dim)
        att = torch.einsum('ijk,ilk->ijl', [q, k]) / math.sqrt(self.atten_dim)
        # hugging face implementation: https://tinyurl.com/22f9b6y
        if attention_mask is not None:
            # Additive mask: -inf entries vanish after softmax.
            att = att + attention_mask
        att = F.softmax(att, dim=-1)  # normalise over context positions
        # BUG FIX: combine as att @ v. The previous einsum
        # ('ijk,ilm->ilk') summed over *both* free indices; since softmax
        # rows sum to 1 this reduced to sum(v) per row, ignoring the
        # attention weights (and the mask) entirely.
        return torch.einsum('ijl,ilk->ijk', [att, v])
class MHAttentionInefficient(nn.Module):
    '''
    Algorithm 5: multi-head attention, implemented naively (but faithfully
    to the paper) as independent single-head Attention modules whose
    outputs are concatenated and linearly projected back down.
    '''

    def __init__(self, num_heads, input_dim, atten_dim, output_dim, **kwargs):
        super().__init__()
        # One full Attention module per head; no weight sharing.
        self.heads = nn.ModuleList(
            Attention(input_dim, atten_dim, output_dim, **kwargs)
            for _ in range(num_heads)
        )
        # Projects the concatenated head outputs back to output_dim.
        self.out_proj = nn.Linear(num_heads * output_dim, output_dim)
        self.num_heads = num_heads
        self.output_dim = output_dim

    def forward(self, current_tokens, context_tokens, attention_mask=None):
        """Run every head on the same inputs, concatenate, then project."""
        head_outputs = [
            head(current_tokens, context_tokens, attention_mask)
            for head in self.heads
        ]
        return self.out_proj(torch.cat(head_outputs, dim=-1))
class LayerNorm(nn.Module):
    """Algorithm 6: layer normalisation of activation vectors.

    Normalises each activation vector along its last dimension to zero
    mean / unit std, then applies a learned elementwise scale and offset.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # embedding dimension being normalised
        self.scale = nn.Parameter(torch.ones(self.dim))    # gamma
        self.offset = nn.Parameter(torch.zeros(self.dim))  # beta

    def forward(self, x):
        # BUG FIX: statistics must be computed per activation vector
        # (last dimension). The previous global x.mean()/x.std() mixed
        # statistics across the batch and sequence dimensions for any
        # batched input. For 1-D input this is unchanged.
        m = x.mean(dim=-1, keepdim=True)
        s = x.std(dim=-1, keepdim=True)
        return ((x - m) / s) * self.scale + self.offset


class TokenUnembedding(nn.Module):
    """Algorithm 7: map embeddings to a probability distribution over tokens."""

    def __init__(self, vocab_size, embed_dim, **kwargs):
        super().__init__()
        self.vocab_size = vocab_size  # number of tokens
        self.embed_dim = embed_dim    # embedding dimension
        self.un_emb = nn.Linear(self.embed_dim, self.vocab_size, **kwargs)
        # BUG FIX: softmax over the *vocabulary* axis. dim=1 was only
        # correct for 2-D (batch, vocab) inputs; the transformers in this
        # file pass (batch, seq, vocab) logits, for which dim=1 normalised
        # across sequence positions instead of over the vocabulary.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, embedding):
        """embedding: (..., embed_dim) -> probabilities (..., vocab_size)."""
        logits = self.un_emb(embedding)
        return self.softmax(logits)
class EDTransformer(nn.Module):
    """Algorithm 8: encoder-decoder transformer.

    Encodes context ids ``z``, then decodes target ids ``x`` with masked
    self-attention plus cross-attention over the encoder output, returning
    per-position probability distributions over the vocabulary.
    """

    class ResNet(nn.Module):
        """Residual wrapper: returns module(...) plus the first input."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, x, y=None, mask=None):
            if (y is not None) and (mask is None):
                return self.module(x, y) + x
            elif (y is not None) and (mask is not None):
                return self.module(x, y, mask) + x
            else:
                return self.module(x) + x

    def __init__(self, embed_dim, mlp_dim, max_seq_len, L_enc, L_dec, vocab_size, num_heads):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.token_emb = TokenEmbedding(vocab_size, embed_dim)
        self.pos_emb = PositionEmbedding(max_seq_len, embed_dim)
        # Encoder stack: [attention, LN, MLP, LN] per block, residual
        # connections around the attention and MLP sub-layers.
        self.encoder_layers = nn.ModuleList()
        for i in range(L_enc):
            multi_head_attention = MHAttentionInefficient(num_heads=num_heads,
                                                          input_dim=embed_dim,
                                                          atten_dim=embed_dim,
                                                          output_dim=embed_dim)
            self.encoder_layers.add_module(f"enc_attention_layer_{i}",
                                           EDTransformer.ResNet(multi_head_attention))
            self.encoder_layers.add_module(f"enc_layer_norm1_layer_{i}", LayerNorm(embed_dim))
            mlp = nn.Sequential(
                nn.Linear(embed_dim, mlp_dim),
                nn.ReLU(),
                nn.Linear(mlp_dim, embed_dim)
            )
            self.encoder_layers.add_module(f"enc_mlp_layer_{i}", EDTransformer.ResNet(mlp))
            self.encoder_layers.add_module(f"enc_layer_norm2_layer_{i}", LayerNorm(embed_dim))
        # Decoder stack: [masked self-attn, LN, cross-attn, LN, MLP, LN].
        self.decoder_layers = nn.ModuleList()
        for i in range(L_dec):
            multi_head_attention = MHAttentionInefficient(num_heads=num_heads,
                                                          input_dim=embed_dim,
                                                          atten_dim=embed_dim,
                                                          output_dim=embed_dim)
            self.decoder_layers.add_module(f"dec_attention1_layer_{i}",
                                           EDTransformer.ResNet(multi_head_attention))
            self.decoder_layers.add_module(f"dec_layer_norm1_{i}", LayerNorm(embed_dim))
            multi_head_attention = MHAttentionInefficient(num_heads=num_heads,
                                                          input_dim=embed_dim,
                                                          atten_dim=embed_dim,
                                                          output_dim=embed_dim)
            self.decoder_layers.add_module(f"dec_attention2_layer_{i}",
                                           EDTransformer.ResNet(multi_head_attention))
            self.decoder_layers.add_module(f"dec_layer_norm2_layer_{i}", LayerNorm(embed_dim))
            mlp = nn.Sequential(
                nn.Linear(embed_dim, mlp_dim),
                nn.ReLU(),
                nn.Linear(mlp_dim, embed_dim)
            )
            self.decoder_layers.add_module(f"dec_mlp_layer_{i}", EDTransformer.ResNet(mlp))
            self.decoder_layers.add_module(f"dec_layer_norm3_layer_{i}", LayerNorm(embed_dim))

        self.unembed = TokenUnembedding(vocab_size, embed_dim)
        # Lower-triangular causal mask, cropped to runtime length in forward.
        self.register_buffer("mask", torch.tril(torch.ones(self.max_seq_len, self.max_seq_len))
                                          .view(1, self.max_seq_len, self.max_seq_len))

    def forward(self, z, x=None):
        """z: encoder input ids (bs, lz); x: decoder input ids (bs, lx).

        Returns (bs, lx, vocab_size) probabilities.
        """
        # BUG FIX: fail loudly instead of an obscure AttributeError on
        # x.size() when the decoder input is omitted.
        if x is None:
            raise ValueError("EDTransformer.forward requires decoder input ids `x`")
        lz = z.size(1)
        z = self.token_emb(z) + self.pos_emb(lz)[None, :, :]
        for name, layer in self.encoder_layers.named_children():
            if "attention" in name:
                z = layer(z, z)
            else:
                z = layer(z)
        lx = x.size(1)
        x = self.token_emb(x) + self.pos_emb(lx)[None, :, :]
        # BUG FIX: crop the (1, max, max) causal mask to the runtime
        # length; otherwise (bs, lx, lx) attention logits cannot broadcast
        # against it whenever lx < max_seq_len.
        causal = self.mask[:, :lx, :lx]
        attn_mask = causal.masked_fill(causal == 0, float('-inf'))
        for name, layer in self.decoder_layers.named_children():
            if "dec_attention1" in name:
                x = layer(x, x, attn_mask)
            elif "dec_attention2" in name:
                # Cross-attention over the encoder output (no mask).
                x = layer(x, z)
            else:
                x = layer(x)
        return self.unembed(x)
class ETransformer(nn.Module):
    """Algorithm 9: encoder-only transformer (BERT-style).

    A stack of bidirectional attention blocks followed by a linear/GELU/
    LayerNorm head and a token unembedding.
    """

    class ResNet(nn.Module):
        """Residual wrapper: adds the first input to the module's output."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, x, y=None, mask=None):
            # Dispatch on which optional arguments were supplied.
            if y is None:
                return self.module(x) + x
            if mask is None:
                return self.module(x, y) + x
            return self.module(x, y, mask) + x

    def __init__(self, embed_dim, mlp_dim, output_dim, max_seq_len, L_enc, vocab_size, num_heads):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len
        self.token_emb = TokenEmbedding(vocab_size, embed_dim)
        self.pos_emb = PositionEmbedding(max_seq_len, embed_dim)
        # Encoder stack: attention -> LN -> MLP -> LN per block, with
        # residual connections around the attention and MLP sub-layers.
        self.encoder_layers = nn.ModuleList()
        for block_idx in range(L_enc):
            attn = MHAttentionInefficient(num_heads=num_heads,
                                          input_dim=embed_dim,
                                          atten_dim=embed_dim,
                                          output_dim=embed_dim)
            self.encoder_layers.add_module(f"enc_attention_layer_{block_idx}",
                                           ETransformer.ResNet(attn))
            self.encoder_layers.add_module(f"enc_layer_norm1_layer_{block_idx}",
                                           LayerNorm(embed_dim))
            feed_forward = nn.Sequential(
                nn.Linear(embed_dim, mlp_dim),
                nn.GELU(),
                nn.Linear(mlp_dim, embed_dim),
            )
            self.encoder_layers.add_module(f"enc_mlp_layer_{block_idx}",
                                           ETransformer.ResNet(feed_forward))
            self.encoder_layers.add_module(f"enc_layer_norm2_layer_{block_idx}",
                                           LayerNorm(embed_dim))
        # Output head: linear -> GELU -> LayerNorm -> unembedding.
        self.fc = nn.Linear(embed_dim, output_dim)
        self.gelu = nn.GELU()
        self.output_layer_norm = LayerNorm(output_dim)
        self.unembed = TokenUnembedding(vocab_size, output_dim)

    def forward(self, x):
        """Token ids (bs, l) -> per-position vocabulary distributions."""
        seq_len = x.size(1)
        x = self.token_emb(x) + self.pos_emb(seq_len)[None, :, :]
        for name, layer in self.encoder_layers.named_children():
            # Attention sub-layers take (x, x); LN/MLP take just x.
            x = layer(x, x) if "attention" in name else layer(x)
        x = self.gelu(self.fc(x))
        x = self.output_layer_norm(x)
        return self.unembed(x)
self.max_seq_len = 512 12 | self.embed_dim = 50 13 | self.vocab_size = 10000 14 | self.token_emb = TokenEmbedding(vocab_size=self.vocab_size, embed_dim=self.embed_dim) 15 | self.pos_emb = PositionEmbedding(self.max_seq_len, self.embed_dim) 16 | self.attention = MHAttentionInefficient(num_heads=3, input_dim=self.embed_dim, 17 | atten_dim=self.embed_dim, output_dim=self.embed_dim) 18 | 19 | def test_shape(self): 20 | # test that the output shape is as expected 21 | batch_size=32 22 | idx = torch.randint(0,self.vocab_size, size = (batch_size, self.max_seq_len)) 23 | idz = torch.randint(0,self.vocab_size, size = (batch_size, self.max_seq_len//2)) 24 | x = self.token_emb(idx) + self.pos_emb(self.max_seq_len) # current token representations 25 | z = self.token_emb(idz) + self.pos_emb(self.max_seq_len//2) # context token reps. 26 | y = self.attention(x,z) 27 | self.assertEqual(y.shape, (batch_size, self.max_seq_len, self.embed_dim)) 28 | 29 | 30 | if __name__ == "__main__": 31 | unittest.main() --------------------------------------------------------------------------------