├── .gitignore ├── LICENSE ├── README.md ├── archive └── exploration.ipynb ├── conf └── config.yaml ├── model.py ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 spencerbraun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Anomaly Transformer in PyTorch 2 | 3 | This is a PyTorch implementation of [Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy](https://arxiv.org/abs/2110.02642). The paper was accepted as a [Spotlight Paper at ICLR 2022](https://openreview.net/forum?id=LzQQ89U1qm_). 4 | 5 | This repository is currently a work in progress. 6 | 7 | ## Usage 8 | 9 | ### Requirements 10 | 11 | Install dependencies into a virtualenv: 12 | 13 | ```bash 14 | $ python -m venv env 15 | $ source env/bin/activate 16 | (env) $ pip install -r requirements.txt 17 | ``` 18 | 19 | Written with Python version `3.8.11`. 20 | 21 | ### Data and Configuration 22 | 23 | Custom datasets can be placed in the `data/` directory. Edit the `conf/data/default.yaml` file to reflect the correct properties of the data. All other hyperparameters can be set in the Hydra configs. 24 | 25 | ### Train 26 | 27 | Once properly configured, a model can be trained via `python train.py`. 28 | 29 | ## Citations 30 | 31 | ```bibtex 32 | @misc{xu2021anomaly, 33 | title={Anomaly Transformer: Time Series Anomaly Detection with Association Discrepancy}, 34 | author={Jiehui Xu and Haixu Wu and Jianmin Wang and Mingsheng Long}, 35 | year={2021}, 36 | eprint={2110.02642}, 37 | archivePrefix={arXiv}, 38 | primaryClass={cs.LG} 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /archive/exploration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "import torch.nn as nn\n", 18 | "import torch.nn.functional as F\n", 19 | "\n", 20 | "import hydra\n", 21 | "\n", 22 | "from omegaconf import DictConfig\n", 23 | "from omegaconf.omegaconf import OmegaConf\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "from sklearn.datasets import load_digits\n", 27 | "\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "\n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "time_series = np.random.rand(1000, 300)\n", 44 | "time_series[500:560, 100:200] += 0.3\n", 45 | "time_series = torch.from_numpy(time_series)" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 10, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "p = torch.from_numpy(np.abs(np.indices((100,100))[0] - np.indices((100,100))[1]))\n", 55 | "sigma = torch.ones(100).view(100, 1) * 2" 56 | ] 57 | }, 58 | { 59 | "cell_type": 
"code", 60 | "execution_count": 12, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "P = torch.ones(10,10) * torch.arange(10).view(10,1)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 14, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "S = torch.ones(10,10) * torch.arange(10).view(1,10)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 18, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "name": "stderr", 83 | "output_type": "stream", 84 | "text": [ 85 | "/Users/spencerbraun/.pyenv/versions/3.8.11/envs/dl/lib/python3.8/site-packages/torch/nn/functional.py:2747: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.\n", 86 | " warnings.warn(\n" 87 | ] 88 | }, 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "tensor(3.4057)" 93 | ] 94 | }, 95 | "execution_count": 18, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 | "source": [ 101 | "lambda row: F.kl_div(P[row,:], S[row,:]) + F.kl_div(S[row,:], P[row,:])" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 20, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "data": { 111 | "text/plain": [ 112 | "tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", 113 | " [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],\n", 114 | " [2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],\n", 115 | " [3., 3., 3., 3., 3., 3., 3., 3., 3., 3.],\n", 116 | " [4., 4., 4., 4., 4., 4., 4., 4., 4., 4.],\n", 117 | " [5., 5., 5., 5., 5., 5., 5., 5., 5., 5.],\n", 118 | " [6., 6., 6., 6., 6., 6., 6., 6., 6., 6.],\n", 119 | " [7., 7., 7., 7., 7., 7., 7., 7., 7., 7.],\n", 120 | " [8., 8., 8., 8., 8., 8., 8., 8., 8., 8.],\n", 121 | " [9., 9., 9., 9., 9., 9., 9., 9., 9., 9.]])" 122 | ] 123 | }, 124 | "execution_count": 20, 125 | "metadata": {}, 126 | "output_type": "execute_result" 127 | } 128 | ], 129 | "source": [ 130 | "P" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 3, 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "data": { 140 | "text/plain": [ 141 | "tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000,\n", 142 | " 0.9000, 1.0000, 1.1000, 1.2000, 1.3000, 1.4000, 1.5000, 1.6000, 1.7000,\n", 143 | " 1.8000, 1.9000, 2.0000, 2.1000, 2.2000, 2.3000, 2.4000, 2.5000, 2.6000,\n", 144 | " 2.7000, 2.8000, 2.9000, 3.0000, 3.1000, 3.2000, 3.3000, 3.4000, 3.5000,\n", 145 | " 3.6000, 3.7000, 3.8000, 3.9000, 4.0000, 4.1000, 4.2000, 4.3000, 4.4000,\n", 146 | " 4.5000, 4.6000, 4.7000, 4.8000, 4.9000, 5.0000, 5.1000, 5.2000, 5.3000,\n", 147 | " 5.4000, 5.5000, 5.6000, 5.7000, 5.8000, 5.9000, 6.0000, 6.1000, 6.2000,\n", 148 | " 6.3000, 6.4000, 6.5000, 6.6000, 6.7000, 6.8000, 6.9000, 7.0000, 7.1000,\n", 149 | " 7.2000, 7.3000, 7.4000, 7.5000, 7.6000, 7.7000, 7.8000, 7.9000, 8.0000,\n", 150 | " 8.1000, 8.2000, 8.3000, 8.4000, 8.5000, 8.6000, 8.7000, 8.8000, 8.9000,\n", 151 | " 9.0000, 9.1000, 9.2000, 9.3000, 9.4000, 9.5000, 9.6000, 9.7000, 9.8000,\n", 152 | " 9.9000])" 153 | ] 154 | }, 155 | "execution_count": 3, 156 | "metadata": {}, 157 | "output_type": "execute_result" 158 | } 159 | ], 160 | "source": [ 161 | "torch.arange(0, 10, 0.1)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 9, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "ename": "NameError", 
171 | "evalue": "name 'sigma' is not defined", 172 | "output_type": "error", 173 | "traceback": [ 174 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 175 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 176 | "\u001b[0;32m/var/folders/8w/r6kg1v9x7bbfzf9dw9gjslc80000gn/T/ipykernel_95641/3924366364.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0msigma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 177 | "\u001b[0;31mNameError\u001b[0m: name 'sigma' is not defined" 178 | ] 179 | } 180 | ], 181 | "source": [ 182 | "torch.exp(p.pow(2) / (sigma))" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 6, 188 | "metadata": {}, 189 | "outputs": [ 190 | { 191 | "ename": "NameError", 192 | "evalue": "name 'p' is not defined", 193 | "output_type": "error", 194 | "traceback": [ 195 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 196 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 197 | "\u001b[0;32m/var/folders/8w/r6kg1v9x7bbfzf9dw9gjslc80000gn/T/ipykernel_95641/3546984218.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgaussian\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msigma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mgaussian\u001b[0m \u001b[0;34m/=\u001b[0m \u001b[0mgaussian\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 198 | "\u001b[0;31mNameError\u001b[0m: name 'p' is not defined" 199 | ] 200 | } 201 | ], 202 | "source": [ 203 | "gaussian = torch.normal(p.float(), sigma)\n", 204 | "gaussian /= gaussian.sum(dim=-1).view(-1, 1)" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 5, 210 | "metadata": {}, 211 | "outputs": [ 212 | { 213 | "ename": "NameError", 214 | "evalue": "name 'gaussian' is not defined", 215 | "output_type": "error", 216 | "traceback": [ 217 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 218 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 219 | "\u001b[0;32m/var/folders/8w/r6kg1v9x7bbfzf9dw9gjslc80000gn/T/ipykernel_95641/494294515.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m 
\u001b[0mgaussian\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 220 | "\u001b[0;31mNameError\u001b[0m: name 'gaussian' is not defined" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "gaussian[0,:].sum()" 226 | ] 227 | }, 228 | { 229 | "cell_type": "code", 230 | "execution_count": 98, 231 | "metadata": {}, 232 | "outputs": [ 233 | { 234 | "data": { 235 | "text/plain": [ 236 | "tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,\n", 237 | " 0.1000],\n", 238 | " [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,\n", 239 | " 0.1000],\n", 240 | " [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,\n", 241 | " 0.1000],\n", 242 | " [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,\n", 243 | " 0.1000],\n", 244 | " [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,\n", 245 | " 0.1000]])" 246 | ] 247 | }, 248 | "execution_count": 98, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "torch.ones(5,10)/ torch.ones(5,10).sum(dim=-1).view(-1,1)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 85, 260 | "metadata": {}, 261 | "outputs": [ 262 | { 263 | "data": { 264 | "text/plain": [ 265 | "tensor([10., 10., 10., 10., 10.])" 266 | ] 267 | }, 268 | "execution_count": 85, 269 | "metadata": {}, 270 | "output_type": "execute_result" 271 | } 272 | ], 273 | "source": [ 274 | "torch.ones(5,10).sum(dim=-1)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "code", 279 | "execution_count": 67, 280 | "metadata": {}, 281 | "outputs": [], 282 | "source": [] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 72, 287 | "metadata": {}, 288 | "outputs": [], 289 | "source": [] 290 | }, 291 | { 292 | "cell_type": "code", 293 | "execution_count": 54, 294 | "metadata": {}, 295 | "outputs": [ 296 | { 297 | "data": { 298 | "text/plain": [ 299 | "tensor([3.6336e-02, 7.9632e-01, 1.7809e+00, 3.8046e+00, 2.3193e+00, 4.2680e+00,\n", 300 | " 8.3718e+00, 7.3727e+00, 6.7060e+00, 1.1051e+01, 1.0537e+01, 1.0064e+01,\n", 301 | " 1.1236e+01, 1.1610e+01, 1.4704e+01, 1.5483e+01, 1.6960e+01, 1.7324e+01,\n", 302 | " 1.8780e+01, 1.9655e+01, 1.9641e+01, 2.2539e+01, 2.2296e+01, 2.1461e+01,\n", 303 | " 2.3893e+01, 2.5538e+01, 2.7476e+01, 2.7799e+01, 2.8607e+01, 3.0070e+01,\n", 304 | " 2.9921e+01, 3.0479e+01, 3.2850e+01, 3.2870e+01, 3.5183e+01, 3.4523e+01,\n", 305 | " 3.5027e+01, 3.6092e+01, 3.7668e+01, 3.8510e+01, 4.0233e+01, 4.1136e+01,\n", 306 | " 4.1330e+01, 4.4310e+01, 4.5134e+01, 4.6767e+01, 4.7380e+01, 4.6682e+01,\n", 307 | " 4.8102e+01, 4.8557e+01, 5.0157e+01, 5.1855e+01, 5.2537e+01, 5.3827e+01,\n", 308 | " 5.4526e+01, 5.5917e+01, 5.6168e+01, 5.6827e+01, 5.9424e+01, 5.6755e+01,\n", 309 | " 5.9634e+01, 6.0564e+01, 5.9774e+01, 6.2367e+01, 6.1566e+01, 6.5457e+01,\n", 310 | " 6.7245e+01, 6.8038e+01, 6.8918e+01, 7.0426e+01, 7.0787e+01, 7.1730e+01,\n", 311 | " 7.2374e+01, 7.3203e+01, 7.5236e+01, 7.4870e+01, 7.5451e+01, 7.7490e+01,\n", 312 | " 7.8921e+01, 7.8873e+01, 7.8570e+01, 8.1434e+01, 8.0754e+01, 8.3289e+01,\n", 313 | " 8.3406e+01, 8.6066e+01, 8.5195e+01, 8.5312e+01, 8.7704e+01, 8.8209e+01,\n", 314 | " 9.1053e+01, 9.0458e+01, 9.1672e+01, 9.2758e+01, 9.5681e+01, 9.5560e+01,\n", 315 | " 9.7419e+01, 9.4693e+01, 
9.6847e+01, 9.7406e+01])" 316 | ] 317 | }, 318 | "execution_count": 54, 319 | "metadata": {}, 320 | "output_type": "execute_result" 321 | } 322 | ], 323 | "source": [ 324 | "torch.normal(torch.arange(0,100).float())" 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 150, 330 | "metadata": {}, 331 | "outputs": [], 332 | "source": [ 333 | "class AnomalyAttention(nn.Module):\n", 334 | " def __init__(self, seq_dim, in_channels, out_channels):\n", 335 | " super(AnomalyAttention, self).__init__()\n", 336 | " self.W = nn.Linear(in_channels, out_channels, bias=False)\n", 337 | " self.Q = self.K = self.V = self.sigma = torch.zeros((seq_dim, out_channels))\n", 338 | " self.d_model = out_channels\n", 339 | " self.n = seq_dim\n", 340 | " self.P = torch.zeros((seq_dim, seq_dim))\n", 341 | " self.S = torch.zeros((seq_dim, seq_dim))\n", 342 | "\n", 343 | " def forward(self, x):\n", 344 | "\n", 345 | " self.initialize(x) # does this make sense?\n", 346 | " self.P = self.prior_association()\n", 347 | " self.S = self.series_association()\n", 348 | " Z = self.reconstruction()\n", 349 | "\n", 350 | " return Z\n", 351 | "\n", 352 | " def initialize(self, x):\n", 353 | " # self.d_model = x.shape[-1]\n", 354 | " self.Q = self.K = self.V = self.sigma = self.W(x)\n", 355 | " \n", 356 | "\n", 357 | " def prior_association(self):\n", 358 | " p = torch.from_numpy(\n", 359 | " np.abs(\n", 360 | " np.indices((self.n,self.n))[0] - \n", 361 | " np.indices((self.n,self.n))[1]\n", 362 | " )\n", 363 | " )\n", 364 | " gaussian = torch.normal(p.float(), self.sigma[:,0].abs())\n", 365 | " gaussian /= gaussian.sum(dim=-1).view(-1, 1)\n", 366 | "\n", 367 | " return gaussian\n", 368 | "\n", 369 | " def series_association(self):\n", 370 | " return F.softmax((self.Q @ self.K.T) / math.sqrt(self.d_model), dim=0)\n", 371 | "\n", 372 | " def reconstruction(self):\n", 373 | " return self.S @ self.V\n", 374 | "\n", 375 | " def association_discrepancy(self):\n", 376 | " return F.kl_div(self.P, self.S) + F.kl_div(self.S, self.P) #not going to be correct dimensions\n", 377 | "\n" 378 | ] 379 | }, 380 | { 381 | "cell_type": "code", 382 | "execution_count": 151, 383 | "metadata": {}, 384 | "outputs": [], 385 | "source": [ 386 | "class AnomalyTransformerBlock(nn.Module):\n", 387 | " def __init__(self, seq_dim, feat_dim):\n", 388 | " super().__init__()\n", 389 | " self.seq_dim, self.feat_dim = seq_dim, feat_dim\n", 390 | " \n", 391 | " self.attention = AnomalyAttention(self.seq_dim, self.feat_dim, self.feat_dim)\n", 392 | " self.ln1 = nn.LayerNorm(self.feat_dim)\n", 393 | " self.ff = nn.Sequential(\n", 394 | " nn.Linear(self.feat_dim, self.feat_dim),\n", 395 | " nn.ReLU()\n", 396 | " )\n", 397 | " self.ln2 = nn.LayerNorm(self.feat_dim)\n", 398 | " self.association_discrepancy = None\n", 399 | "\n", 400 | " def forward(self, x):\n", 401 | " x_identity = x \n", 402 | " x = self.attention(x)\n", 403 | " z = self.ln1(x + x_identity)\n", 404 | " \n", 405 | " z_identity = z\n", 406 | " z = self.ff(z)\n", 407 | " z = self.ln2(z + z_identity)\n", 408 | "\n", 409 | " self.association_discrepancy = self.attention.association_discrepancy().detach()\n", 410 | " \n", 411 | " return z" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": 166, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "class AnomalyTransformer(nn.Module):\n", 421 | " def __init__(self, seqs, in_channels, layers, lambda_):\n", 422 | " super().__init__()\n", 423 | " self.blocks = nn.ModuleList([\n", 424 | " 
AnomalyTransformerBlock(seqs, in_channels) for _ in range(layers)\n", 425 | " ])\n", 426 | " self.output = None\n", 427 | " self.lambda_ = lambda_\n", 428 | " self.assoc_discrepancy = torch.zeros((seqs, len(self.blocks)))\n", 429 | " \n", 430 | " def forward(self, x):\n", 431 | " for idx, block in enumerate(self.blocks):\n", 432 | " x = block(x)\n", 433 | " self.assoc_discrepancy[:, idx] = block.association_discrepancy\n", 434 | " \n", 435 | " self.assoc_discrepancy = self.assoc_discrepancy.sum(dim=1) #N x 1\n", 436 | " self.output = x\n", 437 | " return x\n", 438 | "\n", 439 | " def loss(self, x):\n", 440 | " l2_norm = torch.linalg.matrix_norm(self.output - x, ord=2)\n", 441 | " return l2_norm + (self.lambda_ * self.assoc_discrepancy.mean())\n", 442 | "\n", 443 | " def anomaly_score(self, x):\n", 444 | " score = F.softmax(-self.assoc_discrepancy, dim=0)" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 167, 450 | "metadata": {}, 451 | "outputs": [], 452 | "source": [ 453 | "model = AnomalyTransformer(seqs=1000, in_channels=300, layers=3, lambda_=0.1)" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 168, 459 | "metadata": {}, 460 | "outputs": [ 461 | { 462 | "name": "stderr", 463 | "output_type": "stream", 464 | "text": [ 465 | "/Users/spencerbraun/.pyenv/versions/3.8.11/envs/dl/lib/python3.8/site-packages/torch/nn/functional.py:2747: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.\n", 466 | " warnings.warn(\n" 467 | ] 468 | }, 469 | { 470 | "data": { 471 | "text/plain": [ 472 | "tensor([[-2.2100, 0.2870, -0.2157, ..., -1.2898, -1.4303, -0.3504],\n", 473 | " [-1.1387, -0.4830, -0.9881, ..., -1.1120, -2.4669, -0.5230],\n", 474 | " [-0.3306, -0.7025, -1.5827, ..., -0.3628, -1.8427, 0.7217],\n", 475 | " ...,\n", 476 | " [-0.5868, -0.5519, -1.9108, ..., -0.4716, -2.8175, -1.0170],\n", 477 | " [-1.6752, -0.7690, -2.3892, ..., -1.6920, -2.6238, 0.9201],\n", 478 | " [-1.0995, -0.0956, -0.5864, ..., -2.5304, -2.2143, -0.5381]],\n", 479 | " grad_fn=)" 480 | ] 481 | }, 482 | "execution_count": 168, 483 | "metadata": {}, 484 | "output_type": "execute_result" 485 | } 486 | ], 487 | "source": [ 488 | "model(time_series.float())" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": 170, 494 | "metadata": {}, 495 | "outputs": [ 496 | { 497 | "data": { 498 | "text/plain": [ 499 | "tensor(508.2613, dtype=torch.float64, grad_fn=)" 500 | ] 501 | }, 502 | "execution_count": 170, 503 | "metadata": {}, 504 | "output_type": "execute_result" 505 | } 506 | ], 507 | "source": [ 508 | "model.loss(time_series)" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": null, 514 | "metadata": {}, 515 | "outputs": [], 516 | "source": [] 517 | }, 518 | { 519 | "cell_type": "code", 520 | "execution_count": null, 521 | "metadata": {}, 522 | "outputs": [], 523 | "source": [] 524 | } 525 | ], 526 | "metadata": { 527 | "interpreter": { 528 | "hash": "f28f51a40c57d9bec6a3c7c5ba9f41b1bc9273fe115e19bff919e2ad7386eeca" 529 | }, 530 | "kernelspec": { 531 | "display_name": "Python 3.8.11 64-bit ('dl': pyenv)", 532 | "name": "python3" 533 | }, 534 | "language_info": { 535 | "codemirror_mode": { 536 | "name": "ipython", 537 | "version": 3 538 | }, 539 | "file_extension": ".py", 540 | "mimetype": "text/x-python", 541 | 
"name": "python", 542 | "nbconvert_exporter": "python", 543 | "pygments_lexer": "ipython3", 544 | "version": "3.8.11" 545 | }, 546 | "orig_nbformat": 4 547 | }, 548 | "nbformat": 4, 549 | "nbformat_minor": 2 550 | } 551 | -------------------------------------------------------------------------------- /conf/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | seed: 0 3 | debug: False 4 | silent: False 5 | device: cuda 6 | 7 | max_iters: 1000000 8 | log_interval: 100 9 | val_interval: 5000 10 | model_save_pt: 5000 11 | 12 | lr: 1e-5 13 | batch_size: 32 14 | val_steps: 500 15 | grad_clip: 100. 16 | early_stop_patience: 20000 17 | early_stop_key: "loss/total_edit_val" 18 | dropout: 0.0 19 | results_dir: null 20 | 21 | eval_only: False 22 | half: False 23 | save: False 24 | 25 | model: 26 | pt: null 27 | 28 | data: 29 | path: null 30 | 31 | eval: 32 | verbose: True 33 | log_interval: 100 34 | final_eval: True 35 | 36 | hydra: 37 | run: 38 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f${uuid:}} 39 | sweep: 40 | dir: ./outputs/${now:%Y-%m-%d_%H-%M-%S_%f} 41 | subdir: ${hydra.job.num} -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import numpy as np 8 | 9 | 10 | class AnomalyAttention(nn.Module): 11 | def __init__(self, N, d_model): 12 | super(AnomalyAttention, self).__init__() 13 | self.d_model = d_model 14 | self.N = N 15 | 16 | self.Wq = nn.Linear(d_model, d_model, bias=False) 17 | self.Wk = nn.Linear(d_model, d_model, bias=False) 18 | self.Wv = nn.Linear(d_model, d_model, bias=False) 19 | self.Ws = nn.Linear(d_model, 1, bias=False) 20 | 21 | self.Q = self.K = self.V = self.sigma = torch.zeros((N, d_model)) 22 | 23 | self.P = torch.zeros((N, N)) 24 | self.S = torch.zeros((N, N)) 25 | 26 | def forward(self, x): 27 | 28 | self.initialize(x) 29 | self.P = self.prior_association() 30 | self.S = self.series_association() 31 | Z = self.reconstruction() 32 | 33 | return Z 34 | 35 | def initialize(self, x): 36 | self.Q = self.Wq(x) 37 | self.K = self.Wk(x) 38 | self.V = self.Wv(x) 39 | self.sigma = self.Ws(x) 40 | 41 | @staticmethod 42 | def gaussian_kernel(mean, sigma): 43 | normalize = 1 / (math.sqrt(2 * torch.pi) * sigma) 44 | return normalize * torch.exp(-0.5 * (mean / sigma).pow(2)) 45 | 46 | def prior_association(self): 47 | p = torch.from_numpy( 48 | np.abs(np.indices((self.N, self.N))[0] - np.indices((self.N, self.N))[1]) 49 | ) 50 | gaussian = self.gaussian_kernel(p.float(), self.sigma) 51 | gaussian /= gaussian.sum(dim=-1).view(-1, 1) 52 | 53 | return gaussian 54 | 55 | def series_association(self): 56 | return F.softmax((self.Q @ self.K.T) / math.sqrt(self.d_model), dim=0) 57 | 58 | def reconstruction(self): 59 | return self.S @ self.V 60 | 61 | 62 | class AnomalyTransformerBlock(nn.Module): 63 | def __init__(self, N, d_model): 64 | super().__init__() 65 | self.N, self.d_model = N, d_model 66 | 67 | self.attention = AnomalyAttention(self.N, self.d_model) 68 | self.ln1 = nn.LayerNorm(self.d_model) 69 | self.ff = nn.Sequential(nn.Linear(self.d_model, self.d_model), nn.ReLU()) 70 | self.ln2 = nn.LayerNorm(self.d_model) 71 | 72 | def forward(self, x): 73 | x_identity = x 74 | x = self.attention(x) 75 | z = self.ln1(x + x_identity) 76 | 77 | z_identity = z 78 | z = self.ff(z) 79 | z = self.ln2(z + z_identity) 80 | 81 | 
return z 82 | 83 | 84 | class AnomalyTransformer(nn.Module): 85 | def __init__(self, N, d_model, layers, lambda_): 86 | super().__init__() 87 | self.N = N 88 | self.d_model = d_model 89 | 90 | self.blocks = nn.ModuleList( 91 | [AnomalyTransformerBlock(self.N, self.d_model) for _ in range(layers)] 92 | ) 93 | self.output = None 94 | self.lambda_ = lambda_ 95 | 96 | self.P_layers = [] 97 | self.S_layers = [] 98 | 99 | def forward(self, x): 100 | for idx, block in enumerate(self.blocks): 101 | x = block(x) 102 | self.P_layers.append(block.attention.P) 103 | self.S_layers.append(block.attention.S) 104 | 105 | self.output = x 106 | return x 107 | 108 | def layer_association_discrepancy(self, Pl, Sl, x): 109 | rowwise_kl = lambda row: ( 110 | F.kl_div(Pl[row, :], Sl[row, :]) + F.kl_div(Sl[row, :], Pl[row, :]) 111 | ) 112 | ad_vector = torch.concat( 113 | [rowwise_kl(row).unsqueeze(0) for row in range(Pl.shape[0])] 114 | ) 115 | return ad_vector 116 | 117 | def association_discrepancy(self, P_list, S_list, x): 118 | 119 | return (1 / len(P_list)) * sum( 120 | [ 121 | self.layer_association_discrepancy(P, S, x) 122 | for P, S in zip(P_list, S_list) 123 | ] 124 | ) 125 | 126 | def loss_function(self, x_hat, P_list, S_list, lambda_, x): 127 | frob_norm = torch.linalg.matrix_norm(x_hat - x, ord="fro") 128 | return frob_norm - ( 129 | lambda_ 130 | * torch.linalg.norm(self.association_discrepancy(P_list, S_list, x), ord=1) 131 | ) 132 | 133 | def min_loss(self, x): 134 | P_list = self.P_layers 135 | S_list = [S.detach() for S in self.S_layers] 136 | lambda_ = -self.lambda_ 137 | return self.loss_function(self.output, P_list, S_list, lambda_, x) 138 | 139 | def max_loss(self, x): 140 | P_list = [P.detach() for P in self.P_layers] 141 | S_list = self.S_layers 142 | lambda_ = self.lambda_ 143 | return self.loss_function(self.output, P_list, S_list, lambda_, x) 144 | 145 | def anomaly_score(self, x): 146 | ad = F.softmax( 147 | -self.association_discrepancy(self.P_layers, self.S_layers, x), dim=0 148 | ) 149 | 150 | assert ad.shape[0] == self.N 151 | 152 | norm = torch.tensor( 153 | [ 154 | torch.linalg.norm(x[i, :] - self.output[i, :], ord=2) 155 | for i in range(self.N) 156 | ] 157 | ) 158 | 159 | assert norm.shape[0] == self.N 160 | 161 | score = torch.mul(ad, norm) 162 | 163 | return score 164 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core==1.1.1 2 | numpy==1.21.3 3 | omegaconf==2.1.1 4 | torch==1.10.0 5 | tqdm==4.62.3 6 | transformers==4.11.3 7 | wandb==0.12.10 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | 4 | import numpy as np 5 | import torch 6 | import wandb 7 | from tqdm import tqdm 8 | from torch.utils.data import DataLoader 9 | 10 | import hydra 11 | from omegaconf import DictConfig 12 | from omegaconf.omegaconf import OmegaConf 13 | from transformers.optimization import AdamW, get_cosine_schedule_with_warmup 14 | 15 | from model import AnomalyTransformer 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def train(config, model, train_data, val_data): 21 | 22 | train_dataloader = DataLoader( 23 | train_data, 24 | batch_size=config.train.batch_size, 25 | shuffle=config.train.shuffle, 26 | # collate_fn=collate_fn, 27 | drop_last=True, 28 | ) 29 | 
total_steps = int(len(train_dataloader) * config.train.epochs) 30 | warmup_steps = max(int(total_steps * config.train.warmup_ratio), 200) 31 | optimizer = AdamW( 32 | model.parameters(), 33 | lr=config.train.lr, 34 | eps=config.train.adam_epsilon, 35 | ) 36 | scheduler = get_cosine_schedule_with_warmup( 37 | optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps 38 | ) 39 | print("Total steps: {}".format(total_steps)) 40 | print("Warmup steps: {}".format(warmup_steps)) 41 | 42 | num_steps = 0 43 | model.train() 44 | 45 | for epoch in range(int(config.train.epochs)): 46 | model.zero_grad() 47 | for step, batch in enumerate(tqdm(train_dataloader)): 48 | # minimax association learning: min_loss detaches the series association, max_loss detaches the prior (see model.py) 49 | outputs = model(batch) 50 | min_loss = model.min_loss(batch) 51 | max_loss = model.max_loss(batch) 52 | min_loss.backward(retain_graph=True) 53 | max_loss.backward() 54 | 55 | torch.nn.utils.clip_grad_norm_( 56 | model.parameters(), config.train.max_grad_norm 57 | ) 58 | optimizer.step() 59 | scheduler.step() 60 | model.zero_grad() 61 | 62 | num_steps += 1 63 | 64 | if not config.debug: 65 | wandb.log({"min_loss": min_loss.item(), "max_loss": max_loss.item()}, step=num_steps) 66 | 67 | if not config.debug: 68 | # save a checkpoint at the end of each epoch 69 | torch.save(model.state_dict(), config.train.pt) 70 | 71 | 72 | @hydra.main(config_path="./conf", config_name="config") 73 | def main(config: DictConfig) -> None: 74 | 75 | torch.manual_seed(config.seed) 76 | 77 | logger.info(OmegaConf.to_yaml(config, resolve=True)) 78 | logger.info(f"Using the model: {config.model.name}") 79 | 80 | train_data, val_data = get_data(config) 81 | config.data.num_class = len(set([x["labels"] for x in train_data])) 82 | print(f"num_class: {config.data.num_class}") 83 | 84 | if not config.debug: 85 | timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") 86 | run_name = f"{config.train.wandb.run_name}_{config.model.model}_{config.data.name}_{timestamp}" 87 | wandb.init( 88 | entity=config.train.wandb_entity, 89 | project=config.train.wandb_project, 90 | config=dict(config), 91 | name=run_name, 92 | ) 93 | if config.train.pt: 94 | config.train.pt = f"{config.train.pt}/{run_name}" 95 | # model hyperparameters (N, d_model, layers, lambda_) are assumed to be defined in the model config group 96 | model = AnomalyTransformer(config.model.N, config.model.d_model, config.model.layers, config.model.lambda_) 97 | model.to(config.device) 98 | 99 | train(config, model, train_data, val_data) 100 | 101 | 102 | if __name__ == "__main__": 103 | main() 104 | --------------------------------------------------------------------------------