├── Evaluation.ipynb ├── README.md ├── Run_AutoODE.ipynb ├── Run_DSL.ipynb └── ode_nn ├── AutoODE.py ├── DNN.py ├── Graph.py ├── __init__.py ├── __pycache__ ├── AdjMask_SuEIR.cpython-36.pyc ├── AdjMask_SuEIR.cpython-37.pyc ├── AutoODE.cpython-36.pyc ├── DNN.cpython-36.pyc ├── DNN.cpython-37.pyc ├── Graph.cpython-36.pyc ├── Graph.cpython-37.pyc ├── LSTM.cpython-36.pyc ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── graph_models.cpython-36.pyc ├── neural_odes.cpython-36.pyc ├── seq_models.cpython-36.pyc ├── train.cpython-36.pyc └── train.cpython-37.pyc ├── mobility ├── Mobility.py ├── us_beta.pt └── us_graph.pt ├── population_states.csv ├── torchdiffeq ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-37.pyc └── _impl │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── adams.cpython-36.pyc │ ├── adams.cpython-37.pyc │ ├── adaptive_heun.cpython-36.pyc │ ├── adaptive_heun.cpython-37.pyc │ ├── adjoint.cpython-36.pyc │ ├── adjoint.cpython-37.pyc │ ├── bosh3.cpython-36.pyc │ ├── bosh3.cpython-37.pyc │ ├── dopri5.cpython-36.pyc │ ├── dopri5.cpython-37.pyc │ ├── dopri8.cpython-36.pyc │ ├── dopri8.cpython-37.pyc │ ├── dopri8_coefficients.cpython-36.pyc │ ├── dopri8_coefficients.cpython-37.pyc │ ├── fixed_adams.cpython-36.pyc │ ├── fixed_adams.cpython-37.pyc │ ├── fixed_grid.cpython-36.pyc │ ├── fixed_grid.cpython-37.pyc │ ├── interp.cpython-36.pyc │ ├── interp.cpython-37.pyc │ ├── misc.cpython-36.pyc │ ├── misc.cpython-37.pyc │ ├── odeint.cpython-36.pyc │ ├── odeint.cpython-37.pyc │ ├── rk_common.cpython-36.pyc │ ├── rk_common.cpython-37.pyc │ ├── solvers.cpython-36.pyc │ ├── solvers.cpython-37.pyc │ ├── tsit5.cpython-36.pyc │ └── tsit5.cpython-37.pyc │ ├── adams.py │ ├── adaptive_heun.py │ ├── adjoint.py │ ├── bosh3.py │ ├── dopri5.py │ ├── dopri8.py │ ├── dopri8_coefficients.py │ ├── fixed_adams.py │ ├── fixed_grid.py │ ├── interp.py │ ├── misc.py │ ├── odeint.py │ ├── rk_common.py │ ├── solvers.py │ └── tsit5.py └── train.py /Evaluation.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import time\n", 12 | "import numpy as np\n", 13 | "import time\n", 14 | "import pandas as pd\n", 15 | "from torch.utils import data\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import random\n", 18 | "import os\n", 19 | "import warnings\n", 20 | "warnings.filterwarnings(\"ignore\")\n", 21 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## Inverse scale of the prediction" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "# Read and Preprocess the csv files from John Hopkins Dataset\n", 38 | "# https://github.com/CSSEGISandData/COVID-19/tree/master/csse_covid_19_data/csse_covid_19_daily_reports_us\n", 39 | "direc = \".../ODEs/Data/COVID/\" # Directory that contains daily report csv files.\n", 40 | "list_csv = sorted(os.listdir(direc))#[:-6]\n", 41 | "us = []\n", 42 | "for file in list_csv:\n", 43 | " sample = pd.read_csv(direc + file).set_index(\"Province_State\")[[\"Confirmed\", \"Recovered\", \"Deaths\"]].sort_values(by = \"Confirmed\", ascending = False)\n", 44 | " us.append(sample.drop(['Diamond Princess', 'Grand Princess']))\n", 45 | "us = pd.concat(us, axis=1, join='inner')[:50] \n", 46 | "us_data = us.values.reshape(50,-1,3)\n", 47 | "us_data[us_data!=us_data] = 0\n", 48 | "us_data[:,:,1] += us_data[:,:,2]\n", 49 | "us_data_diff = np.diff(us_data, axis = 1)\n", 50 | "\n", 51 | "# standardization\n", 52 | "std = np.std(us_data_diff, axis = (1), keepdims = True)\n", 53 | "avgs = np.mean(us_data_diff, axis = (1), keepdims = True)" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 7, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "results = torch.load(\"/global/cscratch1/sd/rwang2/ODEs/Main/Results/autofc.pt\", map_location=torch.device('cpu'))\n", 63 | "idx = 0\n", 64 | "print(results[\"preds\"].shape, results[\"trues\"].shape)\n", 65 | "preds = results[\"preds\"][:,:,:,1].reshape(-1,50,7,3) * std + avgs\n", 66 | "trues = results[\"trues\"].reshape(-1,50,21,3) * std + avgs\n", 67 | "\n", 68 | "w1_pred = np.cumsum(np.concatenate([np.expand_dims(us_data[:,132 + idx*7], axis = 1), preds[idx]], axis = 1), axis = 1)[:,-7:]\n", 69 | "w1_true = np.cumsum(np.concatenate([np.expand_dims(us_data[:,132-14 + idx*7], axis = 1), trues[idx]], axis = 1), axis = 1)[:,1:]\n", 70 | "\n", 71 | "preds = results[\"preds\"][:,:,:,2].reshape(-1,50,7,3) * std + avgs\n", 72 | "w1_pred_up = np.cumsum(np.concatenate([np.expand_dims(us_data[:,132 + idx*7], axis = 1), preds[idx]], axis = 1), axis = 1)[:,-7:]\n", 73 | "w1_true_up = np.cumsum(np.concatenate([np.expand_dims(us_data[:,132-14 + idx*7], axis = 1), trues[idx]], axis = 1), axis = 1)[:,1:]\n", 74 | "\n", 75 | "preds = results[\"preds\"][:,:,:,0].reshape(-1,50,7,3) * std + avgs\n", 76 | "w1_pred_lower = np.cumsum(np.concatenate([np.expand_dims(us_data[:,132 + idx*7], axis = 1), preds[idx]], axis = 1), axis = 1)[:,-7:]\n", 77 | "w1_true_lower = np.cumsum(np.concatenate([np.expand_dims(us_data[:,132-14 + idx*7], axis = 1), trues[idx]], axis = 1), axis = 1)[:,1:]\n", 78 | "\n", 79 | "\n", 80 | "preds_total = []\n", 81 | "for i in range(2,6):\n", 82 | " results = torch.load(\"/global/cscratch1/sd/rwang2/ODEs/Main/Results/AutoODE_0823_\" +str(i) + \".pt\", map_location=torch.device('cpu'))\n", 83 | " preds_total.append(results[\"preds\"][:,-7:])\n", 84 | "preds_total = np.array(preds_total)\n", 85 | "ode_std = np.std(preds_total, axis = 0)" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 17, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgwAAAGSCAYAAACPApmhAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nOzdd3iUxb7A8e8kpNNSCASihF4EiVIERDpEBA4WUCJB8AqKHiyASpESVATOUcAjchU5CIIgqNgABVEQxXApCkiVFnpL6KQnc/+Y3WU3uyG9we/zPPssO+/svLOFvL+dqrTWCCGEEELciFtxV0AIIYQQJZ8EDEIIIYTIlgQMQgghhMiWBAxCCCGEyJYEDEIIIYTIlgQMQgghhMiWBAxClEJKqYFKKa2UGljcdRE3ppRap5SS+eui1JOAQdwUlFI+SqkkpdQ0u7TZSqnLSqkyLvK3t1xw7W8JSqlTSqn1Sql/K6XuKtpX4bJ+0cVVh+wopaLt3rtPbpCvnV2+2CKsoigAdp9z+yyOz7McDyvSiokiJwGDuFncC3gBP9uldQLWa63TbvC8I8BEy20GsBzwA14G/lBKfaqUKls4Vb5ppAG9lVIVszg+2JJHCFGKOf3yEqKU6gikA+sBLL92agLvZ/O8WK11dOZEpVQ48AnwOBAAdCu4qt50lgMPAv3I9H4rpfyBR4DvgIeKvmpCiIIiLQyiVFJKlVNK1bbegK7AHiDY8vhRS9bDdvl8clq+1nob0Bk4B9yvlHrQRR3qW5pjjymlUpRSZ5RSi5RS9VzkrauUmqKU2qKUOqeUSlZKHbF0m4RmyjsPWGt5OCFTt0l7F2V3sPSTX7F0waxQSjXI6WstAD8Ax4FBLo71B7yBj1w9USnlqZQaqpRaaXk/kpVS55VSa5RSLoM0pdSdSqnFSqlYS/5zSqk/lFIzlFIedvnKKaXGKaV2Wt6XK0qpg0qpJUqpppnKHKiU+lIpdUgplWjJv0EpFZXVi1ZKBSilJlnKT1BKXVJKbbd8zn4u8pdRSo1RSu231PuYUmqqUsozU74wy2c9L4vzOo2JUMYApdTvlvcjyVL+KqXUYy7KCFVKzbS83mSlVLxS6lulVPNM+WKBCZaHa+2/i5bjGhhgOX7YVdeTUqqm5Xt+wPLenldK/aWU+kApFZjV+ytKHmlhEKXVI8DHLtL3Z3q8zO7fHYB1OT2B1vqsUupDYCzm1/PX1mNKqfstZXtgfj0fAEKBh4HuSqkOWus/7Ip7GBiCCQR+B1KAOzAX2Z5KqWZa6xOWvNbzDAB+yVTn2EzV7AH0Ar4HPgAaAg8AzZVSDbXWcTl9vfmQDswFxltexxa7Y4OBw8CaLJ4bALyLeU9+xARoIUBPYKVSarDWeo41s1LqTuD/AA18aym7PFAbeA7zWaUqpRQmkGkNxABzMN0ioZjvwa/AVrt6/C+wC9NCdQoIxLyPC5RS9bTW4+wrrZSqgfksq1vK+V/MD7C6wDDMZ3Et02tdBNyH+awuW8p/FQgGnszi/cmpScBozPuxFLiEeR+bA32AJXZ1vxtYjXnvV2G+x0GYVqLflFIPaa1XWrLPsKS3A+bj/P2baDneBPM5XrSkX7ScKwTYjPmMVgJfYgLIGphgciYQn8/XLoqK1lpucit1N8wf6t6W2zTMBWScXdo1zHiG3na3SnbPb295zrpsztPJku+IXZo/cAGIAxpmyt8IuAr8kSm9GuDlovyumAvu/2ZKt9YvOot6DbQcTwM6ZTo22XLs1UL+DKIt5xlk+TzSgQ/tjre0HH8N8+NEY7qA7MvwAkJdlF0B2AmcB3zs0t+xlNPLxXP8ATfLvxtb8n3lIp8b4J8prZaLfJ7AT0AqUC3Tsd8t5Y928bwgwNvu8TpL3q1AgF26HybQTAeq2KWHWfLPy+J9X2f+dDukxWNaeXxd1cfu32Us50wC2mXKVxU4gQmYvOzSrZ9z+yzqM89yPMzFsectx150cczP/rOVW8m/SZeEKJW01ke01l9orb/A/EFKBaZZHu8AfIHPrXkst3N5OJX1V38lu7QngIrABK317kz12olpfr9LKdXQLv2E1jrZxetYjfllG5GHugF8prX+KVPabMt9izyWmWta6yOYX62Rds3xgzEXQ1ctQdbnJWutj7tIv4RptfDH/ErOLNHFcy5orTNykC9Da30hU9pBF/lSMGMyymACRwAs3RmtgG3AVBfPi9NaJ7mo80it9Xm7fNeATzEBTDMX+XMrFfN+O9XH7mF3oBbwntb6l0z5TgL/Aqpg93oLiKvP4ZrW2ildlFzSJSFuBh2BzZY/wGCaT8E05+eXstzb9xm3stw3Ua6nPda13DcAdoPpY8Z0awzENN/6A+52z0nJY/22uEg7Zrn3z+7JysxseMnFoRla64su0m/kI+B+oK9S6nPgMWCF1vqkcjG11a4OdwCvAG0xzejembJUs/v3EuBF4Gul1BeYro4NLi74uzEX9EilVHXgG+A3YIslEMhch9uBkZgL5e1A5vEu9nVoablf5SJAuZF8fVbZ+BTza363Umop5rsfYwm87Fm/u9Wz+O7Wsdw3wHQh5Ne3wFvA+0qpCEwXyAZgt9Za1qYoZSRgEKWOZeBfe8tDN8wFeIvdH8AHML+0HjXXadAuZkLkUFXLvX3rhHWg1uBsnms/HXMa5sJ8CvNH8wTXf3UNxDTp54XTRV1rnWZ53e7O2Z1U5PqgNnvzXJWdje+AM5guCg9Mk7PLwY5WSqmWmK6jMpjm/28x/fsZQDhmfIaXNb/WepNS6j5MN0dvTD84Sql9wESt9WJLvnSlVEdgvCWftSXgilJqPqYr4arluTWBTZiL9q+YlpJLmO9QGGYsia0OmPcMrrc+5UgWAZh1umlOPqsbGQYcwoyFGGW5pSmlVgIjtNYHLPms390+2ZRXIFOJtdZHlFItMN0a92PG8gAcU0q9rbX+T0GcRxQNCRhEadQe54tcc5ybru3zROfxXB0s9/9nl2b91dZEa70juwKUUsHAC5g++dZa6yuZjkfmsW75prWO5XorSn7LSlVKfYy5WIVi+tS/z+ZpYzG/5jtordfZH1BKjcYEDJnPEwP0UEp5AU0xF6LngUVKqXNa6zWWfBcwF9JhysycaQc8AwzFXPT7W4ocjrmQPqm1npepDpFcnwVgZb3wV6NwWFstsvr77LTehdY6HTNAcYbl+9YG6IsJDO5QSt1h6RKzfnd7aa2/Ldhqu6a13gM8ZmllaoKZffQ88K5S6prW+r9FUQ+RfzKGQZQ6WutorbXSWivMILhkzOAphWlKBXjWmseSnmuWP7zPWB5+andoo+X+vhwWVRPzf221i2Ah1HI8M2tfdH5/eRa1OZjum1BgruVCdiO1gfOZgwWLdi7SbCzjH37XWo/HBGTgIsCw5D1guTC1wwxKtc9X23L/ZQ7rYP38I5RShfE31Dq+4rbMB5RS5bne5eWS1vqs1nqZ1vpRTOtNLcxgXMj9dxey/y7m6LuqtU7TWm/VWk8FrEGy03RlUXJJwCBKuw7ARrtBZu0t9+vyU6hSqglmml8QsDLTr7GPMb8yJ1iaWzM/1005rpcQa7lvo5Ryt8tXFtNk7+qXpHWq2e15fQ3FwTKW4H7MIk05aW6OBQIs0yVtlFJP4WIgqFKqtXK9nkZly32CJV8NS1dDZv6Y7gX7wXaxlvv2mc4VgYu1JbTWWzGzJMIx4x4y1zFQKZV5HEaOWYLKvcC99gNnLd+daWQaX6GU8lJK3euiHh6YqZNgeV8wYzkOAv9USj3g6vxKqVZKKV+7pOy+i1keV0o1VUpVcPEch89LlA7SJSFKLcuAvXDgDbvk9sBprfXeHBYTZjf2wQMTIDS13AAWYtZPsNFaxyulegNfARuVUj9hZjpozK/CVpgmbm9L/tNKqc8wTcTblFKrMdMGu2Cmt22zvA57+zB95H2VUqmYJaw1sMAyI6HEssz8yKkZmMDgN8tgvUuYGQNtgC8w4w/svQp0VEr9illz4CpmPYtumF/m1hkiTYBlSqnNmAW9TmJmuvTCfM72sxtmYfr+P7cMpDyJ+UV+P2ZNA6eFj4AoTFD6llLqEcu/FWbQYFegPs5rFuTGv4H/AhssA0iTMMGxB7Dd8vqsfDDv3wHM1M0jmO9eF0yL27eWbgFrt9HDmHE0K5RSv2O+fwmY725zTItXCNcv5msx3SSTlVKNsLSAaK3ftBz/CTNo9SOl1JfAFeCi1nomptvnGaXUb5hA5QKmxaMnpmVwRj7eI1HUimr+ptzkVtA3zB9/h/nhmEGFn+Xgue0tz7W/JVqevx7zBzs8mzLCMAvP7Mf8Qb+M+WW4AHgwU15fzOI61jnwxzBT9gJxMa/e8pzmmD/GlzB/sG2vlevrMAzMom7ZrjFRAO9/tOU8g3KQ1+U6DJZjPTBN5VcwLTerMTMmnF4j5mL8MWYWxCXMehv7MK0Z1e3yhWJG528ATmMuTtYxFd1c1KE1pvn+gqUev2Gay63fk2gXzwnEBB77LJ/pRczFdxJ26yFk9flm9zkCT2EC0WTLa/jQ1fcFE0S8anltRy11OWd5T4cAni7KDgamYMbVJGACr/2YIC0KKJMpf5TltSVa6qszHR+OCcyS7T9n4B7MolbbMWtqJGL+D3wMNCrM76fcCv6mLB+qEEIIIUSWZAyDEEIIIbIlAYMQQgghsiUBgxBCCCGyJQGDEEIIIbIlAYMQQgghsiXrMNxAUFCQDgsLK+5qCCFyYd++fQDUq1evmGsiROm0devWOK11pczpEjDcQFhYGFu2uNpgTghRUrVv3x6AdevWFWs9hCitlFIuF4eTLgkhhBBCZEsCBiGEEEJkSwIGIYQQQmRLAgYhhBBCZEsCBiGEEEJkSwIGIYQQQmRLplXm0+XLlzl79iypqanFXRUhbgoeHh4EBwdTvnz54q6KEMJOkQcMSqkHgdeBesBJ4D2t9bRMeZ4DugMtgQCgg9Z6nYuyGgLvAa0we9HPASZqrdPt8ihgNPAsEARsBl7QWm/L72u5fPkyZ86coVq1avj4+GBOJYTIK601iYmJnDhxAkCCBiFKkCLtklBK3QssAzYBPYG5wFSl1EuZsj6BCRRW3aAsf2ANoIFemCBkBDAxU9ZRwDhgquWcV4E1Sqkq+X09Z8+epVq1avj6+kqwIEQBUErh6+tLtWrVOHv2bHFXRwhhp6hbGMYDG7TWgyyPVyulKgLjlVKztNYplvTWWusMpVQjIDKLsoYAPsDDWuvLwI9KqfJAtFLqX1rry0opb0zAMFlrPRNAKRUDxAJDgbH5eTGpqan4+PjkpwghhAs+Pj7SzSdECVPUgx7DgR8zpa0G/DHdCgBorTNyUFY3YJUlWLD6DBNEtLM8bg2UB5balX0N+M7y/HyTlgUhCp78vxKi5CnqgMEbSMmUZn3cIJdl1Qf22idorY8CCZZj1jzpwP5Mz91jl0cIIYQQ2SjqgOEA0DxTWgvLfUAuy/LHDHTM7ILlmDXPVftBkHZ5fJVSnrk8pxBCCFFiXbsGV64UTtlFHTB8ADyolBqslPJXSkUAwy3HctINUeiUUk8rpbYopbacO3euuKtTpObNm4dSyuVtzZo1DnmPHj3K0KFDqVOnDt7e3pQtW5bmzZszadIkLl26VEyvIGesrzM2NtaWFhYWxsCBA3Ndzty5c3NUvhBCFKaMDIiNhbVr4fjxwjlHUQ96nAs0Af4XmI3pPhiJmRp5OpdlXQAquEj3txyz5imrlHLP1MrgDyTYDbK00VrPttSNZs2a6VzW6abw+eefExoa6pDWsGFD27/Xr1/PP/7xD4KDg3nhhRdo1KgRqampbNy4kffff5+4uDimT59e1NXOl6+++irXU/jmzZtHWloa//M//+OQ3r17d2JiYggJCSnIKgohhEuXL8Nff5l7b+/CO0+RBgyWi/ZQpdQ4IBQ4zPWxBBtzWdxeMo1DUErdBvhyfWzDXsAdqA3ss8vqNP5BXBceHk7t2rVdHrtw4QK9e/emQYMGrFmzBj8/P9uxrl27MmLECH7//fdCq1tycjJeXl4FXu5dd91VYGVVqlSJSpUqFVh5QgjhSloaHD4M+/eDnx8EB0NhNvAWy9LQWusLWuu/tNZXgeeA37XWub2Afw9EKKXK2aU9BiQCv1ge/w5cBvpYMyilfDHrMXyf1/rfyubMmcO5c+d47733HIIFKz8/P7p06XLDMpRSvPbaa0yaNInQ0FB8fHxo27Yt27Y5rqXVvn172rRpw3fffcddd92Fl5cXs2bNAiAtLY3JkydTv359vLy8qFq1KiNGjCApKcmhjEOHDtG9e3d8fX2pVKkSL774IsnJyU51ctUlcfjwYfr370+VKlXw8vKiZs2avPjii7a6/fLLL2zYsMHWbdO+fXvAdZdEamoqY8eOJSwsDE9PT8LCwhg7dqzD1MHY2FiUUnz44YeMHz+ekJAQKlasSM+ePTleWG2MQohS6cIF2LABDh6ESpWgbNnCP2eRtjAopVoCbYBtmOmOkUCEJc0+XzMgDLjNktROKRUExGqtt1jSPgBeAJYppaYCNYFoYJp1qqXWOkkpNQUYp5S6gGlVGI4JlN4rpJcJQHQ0TMy8hFQWBg+G2bMd055+Gj76KGfPnzDBnM/e1q3QtGnOnp9Zeno6aWlptsdKKdzd3QH48ccfCQkJoVmzZnkr3OKTTz7h9ttvZ+bMmSQnJzN+/Hg6derE/v37CQi4Pv7177//5oUXXmDcuHHUrFnTdiwqKorvvvuOkSNH0rp1a/bs2cO4ceOIjY3lyy+/BCAlJYUuXbqQmJjI+++/T3BwMB9++CHLli3Ltn6HDx+mRYsW+Pr68vrrr1OnTh2OHj3K6tWrAZg1axZRUVGkp6fz4YcfAjdelXDAgAEsXbqUMWPG0KZNG37//XcmTZrEoUOHWLRokUPeyZMn07p1a+bOncvZs2cZMWIEUVFRrFu3LlfvsRDi5pOaaloUDh+GChVMsFBUinoMQyqmFSAaM8jxV+BerfVfmfINBQbYPY623M8HBoJppVBKdQJmYtZVuAhMt8trNQUTIIwGAoEtQBet9ZkCeD03pfr1HWec3nvvvfz2228AHDt2jOrVq+f7HImJiaxevdrWSnHPPfdQp04dpk+fzhtvvGHLFxcXx+rVqwkPD7el/frrryxZsoT58+fzxBNPANC5c2cCAgKIiopi27ZthIeHM3/+fA4dOkRMTAwtW7YEoFu3bjRu3Djb+k2YMIHExES2b99O1apVbekDBpivZcOGDSlfvjxpaWm2srOyc+dOFi9ezIQJE4i2RHZdu3alTJkyjBs3jlGjRnHnnXfa8oeFhTkEEefOneOVV17h5MmTDnURQtxazp6FnTtNV0TlylDUy5UU9RiGrThPq3SVbyCWwCCbfLuBjtnk0cAky03kwFdffeUw6LFcuXI3yJ03DzzwgEOXRlhYGC1btiQmJsYhX1hYmEOwAPDDDz/g6elJ7969HVpCunbtCphBmeHh4cTExHDbbbc5XNDd3Nx49NFHbRfurKxevZoePXoUyAV6/fr1gGkVsRcVFcW4ceP45ZdfHAKGBx54wCGfNcA5evSoBAxC3IKSkuDvv+HYMahY0bQsFAfZrbKQREc7dxPkxuzZzt0UuZHX7giARo0aZTno8bbbbmPnzp15L9yicuXKLtN27drlkOZqpsHZs2dJSUlxOYYCID4+HoBTp05leZ7sxMfHO80Uyavz588Dzq+lSpUqDset7LtkANsgz8zjM4QQNzet4dQp2LXLtCYUR6uCPQkYRK507tyZH3/8ka1bt9I0H1HJmTPOPULWnT/tuVoiODAwEG9vb3799VeXZVt/hYeEhDgFIFmdO7OgoCDbjon5ZQ0ATp8+Ta1atWzpp0+fdjguhBBWCQmwe7fphvD3B88SsMxgscySEKXXoEGDCAoKYujQoVy7ds3peEJCgtMiT66sXLnS4fmxsbFs3LiRVq1a3eBZxv33309SUhKXLl2iWbNmTjdrwNCqVSuOHTvGxo3XZ+xmZGSwdOnSrIq26dq1K8uXL+fUqVNZ5vHy8iIxMTHbstq2bQvAZ5995pD+6aefAthmVwghREYGHD0K69ebKZKVK5eMYAGkhUHkUkBAAF9++SX/+Mc/uPvuu3n++edtCzdt2rSJDz74gN69e9O5c+cbluPj40PXrl155ZVXSE5OZsKECZQvX55hw4ZlW4f27dsTGRlJ7969GT58OC1atMDNzY3Y2FhWrlzJ1KlTqVu3LgMGDGDKlCk8/PDDvPXWWwQHB/PBBx9w+fLlbM8xceJEVq5cSevWrRkzZgy1a9fmxIkT/PDDDyxcuBAwAx9nzZrFkiVLqFWrFuXKlaNevXpOZTVq1IjIyEiio6NJS0ujdevWxMTE8MYbbxAZGZmjQZhCiJvflStmUOPFixAQAGVK2BW6hFVHlAZt27Zl+/bt/Pvf/2b69OkcP34cDw8PGjRowD//+U+ee+65bMt44okn8PPzY+jQocTFxdG8eXM+++yzHDfPL1y4kPfee4+5c+cyadIkvLy8CAsLIyIiwjZGwdPTkx9//JGhQ4fy3HPP4efnx+OPP0737t0ZMmTIDcsPCwtj48aNjB07ltGjR3P16lWqVatGr169bHlGjhzJvn37GDRoEFevXqVdu3ZZTn2cN28eNWvWZO7cubz55ptUrVqVkSNHMmHChBy9XiHEzSs9HY4cgX37wMfHLMBUEikziUC40qxZM71ly5Ysj+/Zs4cGDXK7yaawLtz05ptvFndVRAmW1/9f1i4eWbdClAYXL5pWhatXTauCZcmbPLt0CUJCID+XJqXUVq2102I70sIghBBCFLHUVLNK46FDZpXG0rCavAQMQgghRBGKj4cdOyAlxQQKbqVk+oEEDKLISTeYEOJWlJJiFmA6etQsvpTLDXKLnQQMQgghRCHS2qyn8NdfZtpkcHDxLsCUVxIwCCGEEIUkMRH27DErNvr7g2Xh1lJJAgYhhBCigGkNJ06Y1Rrd3MCyEnypJgGDEEIIUYCuXTP7P8THl8wFmPLqJnkZQgghRPGLi4MtW0zXQ0ldgCmvJGAQQgghCsDFiyZYKF++dI9VyEopmf0phBBClFxXrsDmzVCu3M0ZLIAEDCILgwcPRimVo82gbmTbtm1ER0dz/vz5fJWzb98+BgwYQLVq1fD09KRatWr079+fffv2OeUdOHAgSimUUri5uVGhQgUaNmzIU089RUxMjFP+devW2fK7ul28eDFfdS9s0dHRTtuAK6WIjo7OVTkzZsxg2bJlOSpfCHFdQoIJFry8wNu7uGtTeCRgEE4SExNtW0AvWrSItLS0PJe1bds2Jk6cmK+AYc2aNdx9991s376dt956izVr1jB58mR27drF3Xff7XI77UqVKhETE8Pvv//OsmXLGDp0KHv27LHtPunKf/7zH2JiYpxu5cqVy3Pdi0tMTAyDBg3K1XOyChgGDRrkMtASQkBysumGUAr8/Iq7NoVLxjAIJ19//TWXL1/mgQceYOXKlfzwww/06NGjWOoSHx9P3759adKkCT///DPelvC9bdu2PProo3Ts2JG+ffuyb98+AgMDbc/z9PSkZcuWtsedOnXi2WefZdiwYUyePJmmTZvyyCOPOJyrQYMGDs8pbFprUlNT8SyEze4L8nWEhoYSGhpaYOUJcbNITYU//jD3FSsWd20Kn7QwCCfz58/H39+fefPm4ePjw/z5853yDBw4kLCwMKf09u3b23YLnDdvHk8++SQAderUsTXxx8bGAnD58mWGDh1K1apV8fLyol69ekyfPt1h6eg5c+YQHx/Pu+++awsWrLy9vZkxYwbx8fHMmTMn29ellOJf//oXlStXZsaMGTl8N24sNjYWpRSzZs1i+PDhBAcH4+vrS48ePWyv0yosLIyoqCjmzp1L/fr18fT0ZMWKFQAkJCQwcuRIatSogaenJzVq1GDSpElkZGQ4lPHnn39y33334e3tTbVq1XjjjTdcLrXtqkti+/btPPTQQwQGBuLj40O9evWYPHmyrW5Hjhzh008/tX1OAwcOBFx3SeTks7N29Xz77bcMHTqUoKAggoKCiIqKKvHdPEJkJz0dtm83YxduhWABpIWh8ERHw8SJOcs7eDDMnu2Y9vTT8NFHOXv+hAnmfPa2boWmTXP2fDsnT55kzZo1DB48mEqVKvHggw+ybNkyLly4gL+/f67K6t69O2PHjuXNN9/k888/t/1KDQkJISMjg+7du/PHH3/w+uuv07hxY1asWMHw4cM5d+4cb731FgA//fQTVapUoXnz5i7P0aJFCypXrszPP//MyJEjs62Tp6cnnTp14osvviAtLY0ydhOkMzIynLpflFK452C/2cmTJxMeHs7HH3/M2bNnGTNmDF27dmXXrl14eHjY8q1du5Zt27YxYcIEgoODCQsLIy0tjYiICHbv3s24ceNo3LgxGzdu5I033uD8+fO88847AMTFxdGxY0eqVKnC/Pnz8fLy4t///jdHjx7Ntn6bNm2iffv21K5dm+nTpxMaGsr+/fvZsWMHAF999RUPPPAATZo0sQUalbLYPi+nn53Viy++SI8ePVi0aBH79u3j1Vdfxd3d3WUgKkRpkJFhtqSOiysdu0wWFAkYhIOFCxeSnp7OE088AcCAAQNYvHgxS5YsYciQIbkqq1KlStSqVQuA8PBwateubTu2fPlyfvvtNz7++GPbL9muXbty7do13nnnHYYPH05QUBDHjh1z2ZJhLywsjGPHjuW4XrfffjspKSnEx8dTuXJlW3pERIRT3jvuuIOdO3dmW2a5cuX45ptvcLNsO1e3bl3atGnDJ598wlNPPWXLd+HCBbZu3UoVu2XfFixYwG+//cYvv/xC27ZtAdOFAjBx4kRGjhxJcHAw06dP59q1a6xevZrbbrsNgC5dulC9evVs6/fyyy8TGBjIxo0b8fX1BaBjx46243fddRdeXl4EBQVl252xcuXKHH12Vm3btuW9996z5du3bx9z5sxh3rx5MphSlDpaw969ZhVHuz8ftwTpkhAO5s+fT506dWjVqhUAnTt3pmrVqgX+a3D9+vW4ubnx+OOPO6RHRUWRkpJSqIPsrM3mmS9W77//Pm8kZ6UAACAASURBVJs3b3a4LVmyJEdl9u7d2xYsANx7772EhoY6vY6WLVs6BAsAP/zwA9WrV6d169akpaXZbl27diU1NZWNGzcCZiBjy5YtbcECgJ+fHz179rxh3RISEtiwYQP9+vWzBQv5kdvPrnv37g6PGzduTHJyMmfOnMl3XYQoagcOwOHDN9+iTDkhLQyFJTrauZsgN2bPdu6myI08dEds2bKF3bt3M3LkSIc+5ocffpiZM2fy999/U7du3bzXyc758+cJCAhwGvBnvZhaZ1WEhoZm+ws/NjaWJk2a5Pjcx44dw9PTk4CAAIf0unXr0qxZsxyXY6+yi58alStX5sSJEw5pISEhTvnOnj3LkSNHHLou7MXHxwNw6tQpGjVqlKNz27tw4QIZGRkFNnAxp5+dVeb32csyST0pKalA6iNEUTlyxGxPXVp3m8wvCRiEjbUVYerUqUydOtXp+CeffMKbb74JmAGHKSkpTnni4+MdZitkJSAggPPnz5OSkuJw4Tl9+rTtOJim+TVr1rB582aX4xg2bdrEmTNnHJrXbyQlJYU1a9bQsmVLh/EL+eXq1/KZM2cIDw93SHPVBB8YGEiNGjVsU1kzs3bJhISEZHmeG/H398fNzc0peMmrnH52QtxMTp0y4xaCgsxmUreiW/Rli8xSUlJYvHgx99xzD2vXrnW6hYeHs2DBAltzfvXq1Tlz5gznzp2zlXHw4EGnhZSsvyYTExMd0tu1a0dGRgaff/65Q/qnn36Kp6enrUtk0KBB+Pv78+KLLzr9Ik1KSuKll14iICAgR2sOaK159dVXOXv2bL4XpMrsiy++cJjRsGHDBo4fP257HTdy//33c+zYMcqWLUuzZs2cbtbxAK1atWLjxo0O4zWuXbvGd999d8PyfX19adOmDQsXLnT6HOx5eXnd8LhVTj87IW4WcXHw558QGHjzbCSVF7fwSxf2VqxYQXx8PO+8845tWqS9Z555hmeffZZ169bRoUMH+vTpw7hx44iKimL48OHExcUxefJkh8FuAA0bNgTM+IABAwbg4eHBnXfeSbdu3WjTpg1Dhgzh3Llz3HHHHaxcuZI5c+YwevRoWzlBQUEsXryYhx56iFatWjFs2DBq1KhBbGws06dPZ+/evXz11VdOrRopKSm2vv+EhAT27dvHwoULiYmJYezYsTz44INOr3HPnj2ULVvWKb1x48b4ZbMiy5UrV3jwwQd55plnOHfuHKNHj6ZOnTq2waM30q9fPz7++GM6derEiBEjaNKkCSkpKRw8eJBvv/2Wr7/+Gl9fX4YNG8asWbPo2rUr0dHRtlkSPj4+2Z7j7bffpl27drRq1YoRI0YQGhrKoUOH2LZtm21AYsOGDfn1119Zvnw5VapUISgoyOWA05x+dkLcDKz7Q1SsCFn0Gt4yJGAQgOmOKFeuHH369HF5PDIykuHDhzN//nw6dOhA7dq1+eKLL2wX37p16zJt2jSnKXXWaXqzZ8/mo48+IiMjg8OHDxMWFsaKFSsYM2YMU6dOJT4+nrCwMKZNm8ZLL73kUEZERARbt27lrbfeYtSoUcTFxREYGEjHjh1ZuHChLSixd+7cOVq1aoVSCj8/P0JDQ2ndujXvvPNOlrMAXnjhBZfpmzdvznZsw+jRozlw4AADBw7k2rVrdOjQgZkzZ2Y5LsGeh4cHq1atYsqUKcyePZvDhw/j5+dHrVq16N69u63ZPygoiJ9++okXX3yRAQMGEBgYyJAhQ0hLS+P111+/4TmaN2/Ohg0bGD9+PM8//zzJyclUr17dtk4GmKmhgwcP5tFHHyUxMZEBAwYwb948p7Lc3Nxy/NkJUZpdvXrz7w+RG8rVoi/CaNasmd6yZUuWx/fs2UODBg2KsEaipImNjaVGjRp89NFHuV6KWdxYXv9/WVvI1q1bV7AVEreUhATYuNF0QZSmJZ8vXYKQEMjPpUkptVVr7fQrScYwCCGEEHZupf0hckMCBiGEEMLCfn+I8uWLuzYli4xhECIfwsLCXO7lIIQofez3h8jB7PBbjrQwCCGEuOXZ7w8hwYJrEjAIIYS4pdnvD3ErbSaVWxIwCCGEuKXdyvtD5IYEDEIIIW5Zt/r+ELkhAYMQQohb0s20P8Rff8HIkbBrV+GdQ2ZJCCGEuOXcDPtDpKfDL7/AwoWwY4dJS0mBzp0L53yl9G0SQggh8qa07w+RmAjffguLF8Px447HNmyAs2fzt9JjVkp5I4woKEqpbG+uNiIqCV577TW6du0KQGpqKt7e3vz8888Oefbu3YtSioULF+bpHOvXr6dZs2b4+vqilGLv3r35rrdVXFwc0dHR7LD+RLDTsmVL7r///gI7lxC3Ouv+EGXLlr79IeLi4P33oXt3+Pe/HYOFMmWgRw/44IPCG7wpLQwCgJiYGIfHDz30kG3jKCuvEvq/a+vWrTRt2hSAv/76i+TkZO6+++4CPcfAgQMJDg5mxYoV+Pj4FGjwFBcXx8SJE6lduzZ33nmnw7H//ve/uLu7F9i5hLiVJSbCpk3g6Qk52OS1xFm2DD7+2DGtfHl45BF49FEzJfTSpcI7vwQMAsBpB0cvLy+CgoKy3Nkxs+Tk5GILKLZu3Wrb+GnLli3UqlWLihUrFlj5SUlJHD58mEGDBtGhQ4cCKzcn7rjjjiI9nxA3q+Rk07KglGldKOm0dp610acPzJ9vXku1avD449CzJ/j6Fk2dpEtC5Frfvn2pXbs269evp2XLlvj4+DB+/HiSkpJQSjFlyhSH/NbugM8++8whfc2aNbRv356yZctStmxZunfvzp49e3JVl6NHjxIXF2drYdiyZYvt39kZNWoUZcqUYf/+/URERODn50eNGjWYPHmybbnnDz74AB/LT5HXXnsNpRT169fP9WtYunQpLVu2xM/Pj/Lly9OyZUu+//579u7da9uRsX///rbuH+t7Zd8lceTIEdzc3Jg9e7ZT+a+//jre3t5cvHjRlrZkyRJatGiBr68v/v7+9O3blxMnTjg8b968eTRp0gQ/Pz8qVKhAkyZNmDt3bo7ePyFKi9K0P0RqKixfDv36mSmf9vz94dln4V//Mq0Njz1WdMECSMBQKHIyHqCoboUlLi6O/v3788QTT/D999/Tu3fvXD1/2bJlREREEBQUxKJFi1iwYAHnzp2jbdu2nDp16obPtQYmSimqV68OQM2aNVFK8dFHH7F06VKUUnh7e2dbD601Dz/8MN26deObb76hW7dujBkzxnbBfvjhh23jIZ577jliYmJYsmRJrl7D22+/zWOPPcbtt9/OwoULWbp0KT179uTIkSOEhYXZzhUdHU1MTAwxMTF06dLFqa7Vq1enbdu2LsdhLFy4kJ49e9paVmbMmEFkZCR33XUXX375JbNmzWLr1q106NCBhIQEAH766SeefPJJunTpwrfffsvSpUsZOHCgQ9AhRGlnvz9EATY8FrhLl0x3Q8+eEB1t1oZYtMg5X1QUdOwIxdFTWeRdEkqpB4HXgXrASeA9rfW0THkUMBp4FggCNgMvaK232eUZCGTqzQHgWa31B7kpS+TepUuXWLJkCREREba0pKSkHD03IyODF198kYiICL744gtbert27ahZsybvvvuuUyuFPS8vL/78808AJk6cSGpqKm+++Sbx8fF07tyZr776irCwMNxyMLE6IyODMWPGEBkZCUCnTp1Ys2YNixcvJjIykuDgYHwtIfxtt91m66LJ6WuIj49n3LhxREZGssjuf7/9QMYmTZoAUKtWrWy7gPr378/gwYOJjY21jaPYuHEj+/fv5+233wbg4sWLvPbaawwZMoRZs2bZntu0aVMaNmzIJ598wpAhQ4iJiSEkJMT2PMDh8xSitLPfH6KkLvl8/LgJDL79FjL/Cf3pJxg+vOQMzizSFgal1L3AMmAT0BOYC0xVSr2UKesoYBww1ZLvKrBGKVXFRbEdgVZ2t2X5KEvkkK+vb54vLrt27eL48eNERUWRlpZmu5UvX57mzZuzfv36Gz5fKUV4eDjh4eEcOHCAiIgIwsPDSUhIoFy5cvzjH/8gPDzcaQBhVrp37+5Q9h133MHRo0cL5DX8+uuvJCUl8fTTT+eoLtnp06cPXl5eDq0MCxYsICgoiG7dutnOmZCQQL9+/RzqVrNmTWrWrGmrW/PmzTl16hQDBw5k5cqVXL58uUDqKERJUNL3h9i+HV55BR56CJYudQwWKlWCoUNNt0NJCRag6LskxgMbtNaDtNartdZvAP8BxiulPAGUUt6Yi/xkrfVMrfUaoA+ggaEuytystd5odztrPZCHsgqE1rrE3ApLlSp5j7fOnjUfUb9+/fDw8HC4rVmzhvj4+Bs+Pz09nbS0NOLi4ti1axetWrUiLS2N3377jebNm5ORkUFaWlqO6uLu7k75TJ2aXl5e2baW5PQ1WO9DQ0NzVJ/slC9fnl69etkChtTUVJYsWULfvn3xsEwot9atTZs2TnXbv3+/rU4REREsXryYgwcP0qtXLwIDA4mIiGBXYS4VJ0QRKan7Q1y5Ak8+CU89BWvXmsDGqk4d0x3x7bcwcGDJG29R1F0S4cD7mdJWA8MxrQO/AK2B8sBSawat9TWl1HdAN2BsLs5XkGUJO67GR3h4eODu7k5KSopDeuYAINCyd+w777xD27ZtncrJbuxBtWrVOHPmjO1x8+bNneoBcOrUqXwFNjeS09cQFBQEwIkTJ6hdu3aBnLt///4sWbKEzZs3c+rUKeLj4+nfv79T3RYtWkSdOnWcnm8fIPXt25e+ffty5coVfv75Z1599VW6d+9ObGxsgdRViOJQkveHKFvWdJXYa93aDHJs0aLk1ddeUQcM3kBKpjTr4waYgKE+kA7sz5RvD/CYizIPKqUCgYPANK31h3bHcluWyAd3d3eqVavGzp07HdJXrFjh8Lhx48ZUrVqVPXv2MHz48FyfZ9WqVaSmpjJ16lQuXrzI5MmTuXr1Kh07dmTBggXUq1cPuH6xLgw5fQ333XcfPj4+zJ49m3bt2rnMY52OmpiYmKNzR0REEBwczIIFCzh16hT16tWjRYsWtuNt27bFx8eHQ4cO2cZmZKdcuXL06tWLffv2MXLkSC5fvuzU8iJEaXDkSMnZH+LcOThzBho1up6mlBm4OH48dOtmAoVatYqvjrlR1AHDAaB5pjTrX7oAy70/cFVrnZ4p3wXAVynlqbVOAU5hxiZsAtyBvsAHSilfrfX0XJZlo5R6Gnga4Pbbb8/La7yl9e3bl2nTpjF16lSaNWvG2rVr+fzzzx3yuLu7M3PmTPr06UNCQgKPPPIIgYGBnD59mg0bNlC3bl2GDs26x8g6SHDHjh0MGzaMZs2a8c033xAQEEBkZGSOBjvmV05fQ0BAAK+//jqvvPIKGRkZPPbYY/j6+vLnn39SoUIFhgwZQmhoKOXLl+fTTz+lXr16+Pr6UqtWLfz9/V2eu0yZMkRGRrJw4UKuXr3KuHHjHI4HBAQwZcoURowYwcmTJ4mIiKBcuXKcOHGCtWvX0q1bN3r37s2oUaO4fPky7dq1IyQkhKNHjzJr1ixatmwpwYIodbSGQ4fMuIWgoOLdH2L/frO/w6pVULUqfPGFY/DSoQN8952pZ2lS1PHXB8CDSqnBSil/pVQEpjsCIOMGz3OitV6ltX7TMhbie631AEzXw1ilVJ5fl9Z6tta6mda6WaWSOFKmhJswYQKDBw9m+vTpPPTQQxw+fJh58+Y55XvooYdYu3Yt58+f56mnniIiIoJRo0YRFxfn8Gs5KwcPHuTvv/+2zTb4/vvviYiIKJJgwSqnr+Hll19m0aJFHDhwgMjISPr06cM333xDjRo1ANOFMmfOHE6fPk2nTp1o3rw5q1atuuG5+/fvT3x8PCkpKURFRTkdf+GFF/jiiy/YuXMn/fr1o3v37kycOBGlFI0bNwbMGg/79+/nxRdfpEuXLowZM4auXbvy9ddfF+C7JETh0xr27TPBQnBw8QULsbHwz39CZCSsWAFpaXD0KGQex12mTOkLFgBUYQ6MczqZUu7Au8AQTKtAAjASeA94Ums9Tyn1HGYgpJd9y4BS6hUgWmvtd4Py+2CChlpa60P5KQugWbNmesuWLVke37Nnj23RHSFEwcrr/6/27dsDsG7duoKtkCiRMjJg925zYS7OMQvbtpkpkJknG4WHw3PPQQGvVp+lS5cgJCR/m08ppbZqrZtlTi/SOMxy0R6qlBoHhAKHMeMMADZa7vdigonawD67p9e3HLvhKTLd56csIYQQJVhaGvz1F5w+XbzBwk8/wbhxZmtpMN0PHTuasQr24xdKu2IZEqK1vqC1/ktrfRV4Dvhda229gP8OXMZMfwRAKeWLWUPh+2yK7g3EAdYFNfNTlhBCiBLKutzzmTPFGyx89hmMGnU9WAgIgHnzYMqUmytYgCJuYVBKtQTaANsw0x0jgQhLGgBa6ySl1BRgnFLqAqYlYDgmuHnPrqwvMQMed2BaER6z3F7QWmfkpiwhhBClR3IybN1qtqourqFmWsN//gMLFlxPu/12k1ZAy66UOEU9NCQVc1GPxgxy/BW4V2v9V6Z8UzAX9dFAILAF6KK1PmOXZx/wP8BtgAJ2A09orRfkoSwhhBClQGKi2XUyNRUsS44Ui8wtGnfeCdOmlez9KvKrqMcwbMV5WqWrfBqYZLlllWcMMKYgysoPrXWhbvIkxK2oKAdji9Lj6lUTLEDJuDA//7zpEklJgTffhBzsd1eqFeNM1dLPw8ODxMRE2+ZEQoiCkZiYaFuxUwgwo/83bQJPT7NaYkng5gYTJ5r74tg90kl6On6H9+CRWAYa1M8+fy7J9tb5EBwczIkTJ0hISJBfREIUAK01CQkJnDhxguCStgmAKDbnz0NMDPj4FF+wcOAATJpkZmbY8/AoGcGC+8V4Kq77moBN3+N+/lyhnENaGPLBuhreyZMnSU1NLebaCHFz8PDwoHLlyrLapADg7FnYsgUqVCi+Jv/Nm+Hll+HaNRMcjBxZgvZ8SEvF9+/t+O7aRIZPWZIqVCq0C7sEDPlUvnx5+cMmhBCF4MQJsw20v7/piigOP/xgdpC0tix8/z307w/VqhVPfeyViTtFuS1rcb92mdSgEHAvA9duvNtvvs5XaCULIYQQeRQbC7t2Fd++EFrDJ5/Ae3YT8CtVMtMmiztYUMlJ+O7Zgu/f20kr509qpaKpkAQMQgghSgytzXiBv/82F+jiGB+Qng5vvw32++bVrGmChSpVir4+9jxOHaHclrWolGRSgkOLdEtOCRiEEEKUCBkZZgOpw4fN6o3FsT11UpJZ5nnt2utpd99tAoji7H12S7yG3/bf8T6yj1T/SugKzotQXEq4yuJN35N2wI8Z/e8r8DpIwCCEEKLYpaebLojjx6Fy5eIZVHjxotlAaseO62ldupipk8U1hoKMDLyOHaDsn+vRKFKq3O705uw7GcuC35bzzdZ1JKWm4OPpzYQL7+Dv71+gVZGAQQghRLFKSzODG8+dK959ISZOdAwWoqLghReKp6UDwO3qJcr+sR7P00dJC6iM9vSyHUtLT2fNzv9jwW/L2XRwp8PzElOSmDdvHsOGDSvQ+kjAIIQQotikpMCff5qFmYprXwirESNMK8eFC6alITKymCqSno73oV2U3fE7GR5epFa53Xbo/NXLLN24ikW/f8+pi3FOT61X6TYG/iOSZ555psCrJQGDEEKIYpGUZDaRSkgo3n0hrEJDYcYMOHkSOncunjq4X4yj3Oa1lLl4jtTAKlDGrHi66/hBPvl1Ocv/XE9KmuO6P+5ubnRt3Ir+9/WgsXdl/O6sXSgrEEvAIIQQosglJJgFkdLSzJbQxeH0aedZDw0bmluRS0vFd++f+O7ZQoZvWVIr30ZKWiqr/viFhb+t4I/YvU5PCShbgcdaRhDZ+n5CKgYBkHxK1mEQQghxk7hyxewL4e5ePJtIaQ0ffQTz5sH778NddxV9Hex5nDtJ2S1rcb92hdSgqpy7dpnPVi3ms5hVnL183il/49tq079NDx4Ib4OXR9GNxpSAQQghRJGxbiLl5QV+fkV//rQ0mDIFvv7aPB4xAubOhbCwoq+LSk7Cd/dmfP/eTmo5fzYlXGHB6s/4YfsGUtMdN63wcC/D/U3upX+b7oRXr1csuyRLwCCEEKJIxMebboiyZc1GUkUtIQFGj4YNG66nNWhgVpMsUlrjeTKWsn+sIyUxkSXHDrBgwwr+OnbAKWulcv5Etu5G31YRVCpfsNMkc0sCBiGEEIXuzBn44w+ziZSXV/b5C1p8PLz0EuzZcz2te3ezSFNRLj3tlnAVv+2/c2HX//HB3j/5bMvPnL96ySnf3WH16X9fD7o2boVnmZKx1bsEDEIIIQrV8eNmnYXAQLMddFE7csSsp3DixPW0p56CIUOKcM2HjAw8j/zNnm/nMn/Lz6za9yfpGRkOWTzLeNDjrrY8cV8P7gitVUQVyzkJGIQQQhSaw4dh9+7i20Rqxw4YNsyMnQCzCNPIkfDII0VXh9S4U/w89y0+XfcVu8+ecDpepWIQ/Vp3o0/LrgSWrVB0FcslCRiEEEIUOK1h/35zK65NpNatg9deg+Rk89jLCyZPhrZti+b8J48d5Ou5b7Lsx8+5mHTN6XiLWo3o36YHnRvdQ5nieINySQIGIYQQBcq6iVRsbPFtIgVmnQVrsFCxolmUqVGjwj2n1prNm39m6cK3WR+zigytHY57e3jSq2kHoto8QP2qNQq3MgVMAgYhhBAFJj0ddu404wWKc18IgL59zWDLtWvhvffgttsK71xaa37+eRmzPxzPwUO7nY6HBlQmqs0D9G7RhQq+ZQuvIoVIAgYhhBAFIiPDBAsnT5odJ0uC55+HJ58s3K2p9+zZyvTpw/njj/VOx+6tG07/+3rQvkFT3N1KfrfDjUjAIIQQIt/sg4Xg4KI//5Ej8O67MGGCmbpp5eZWeMHCuXMnmTXrNZYvn4+263rw8/LhoWYdiGrTg1qVQwvn5MVAAgYhhBD5kpFhdnk8frzoWxa0huXL4V//gsREePNN8+/C7ApJSkpg4cJ3mD9/KomJ1wczlnFzp9+9DzC0a18q+pUrvAoUEwkYhBBC5FlGhpk2efx40bcsXL0Kb70Fq1dfT/vtNzh4EGrXLvjzZWRksGrVYmbOHMWZM8cdjnW8owUjew6kZnDxtChobQZ4piaDXyEFSxIwCCGEyBNrsHD0aNEPcPzrLzNl8uTJ62lhYSaAKIxgYceOGKZNG8bOnf/nkF4vpDqjez3FvXXDC/6kN6A1pKaaLcK1Nu99uXIQcjv4Vcn++XkhAYMQQohc09pMnSzqYCE9HebPhw8/NP+2evBBs5FUQe9RcerUEd57bxSrV3/mkB5YtgIvdetHn3u6FNlgxpQU04pgXSCybFmoVs3ce3uDe+JVSMyAQlp6WwIGIYQQuaK12ZPhyJGiDRbOnoXx42HLlutpZcvC2LHQuXPBnish4Srz5k3h00/fITk5yZbu6V6Gge16MaRzH8p5+xbsSTNJTTUBgjUw8vWFkJDrm3fZ1npKSIDPv4VvvoF+/aBOnUKpjwQMQgghcszasmBdlKmogoWTJ6F//+tLPAPceSdMmmQuogUlPT2d5cvnM2vWa8THn3Y4dv+drXil55PcHlg4bf5paSZASLPsbO3tbVbJLFfOBAhO+3AkJ8PKlfDll3D5skn7+muzAEUhkIBBCCFEjmgN+/bBoUNmNkRRjlkICYG77jLLPbu5mc2jnnqqYPen2LJlHdOmDePvv7c5pDeqVpMxDw2mec07Cu5kmJaDpKTrAYKnJwQEmGmhLgMEq9RU+PFHWLoUzp93PObmZppiCoEEDEIIIbJVnMECmPONHWu2qR46FJo2Lbiyjx8/yLvvvsLatV85pFcuV5Hh3QfwYLMOuBXA+tYZGaZRICXFPC5TBir6Q4XyJkDI0bbfO3eaBSfOnHFMDw6GyEho3BiqVs13XV2RgEEIIcQNWTeSOniwaLohtIZVq6BDB8eLaMWKMHduwZ3/6tVL/Pe/b7J48bukpaXa0r3LeDCo3YMM7twHXy/vPJefoSEl+fp+Fu7upvWgQgXw8zMtCrl+LRUrwrlz1x8HBMCjj0KXLqZJIj4+z/XNjgQMQgghsmQNFg4cKJqNpC5dMosvrV1rfjCPGOF4vCCChbS0NL7++iM++GA8Fy/GORx78M7WDO/1FCH+lfJUdno6XLtm3jfrKpMhIWbAord3LuuvtSnQvt8lNNREUps2Qe/e8MADOWyayD8JGIQQQmTpwAETMBRFsPDHHzBu3PXW9sWLoU0buOeegjtHTMwqpk8fwaFDuxzS765WgzEPD6FJjQZ5LjshwQwvCA01AxW9vfPxnm3fDp9+Cnff7TyIceBAGDzYRCFFSAIGIYQQLh08CH//XfjBQloa/Pe/5mZdYwCgTx9o0qRgznH48B5mzHiZDRtWOqSHlg/glW796NaiMyqPzRcZGWaSQtmyULeuCRTybO9eWLgQduwwj48ehe7dTQRiZb9ZRhGSgEEIIYSTgwfNtauwg4XTp81gxm12ExMqVDAtDe3b57/8ixfjmT07mi+//F/S7VZ68vP04rnW3XgiIhJvr7yv9pSUZMYohN4GwZXy8V4dOmRaFDZvdkxPSTHLaRZkM0seScAghBDCwaFDJliolJ8LYA789JMZr3DlyvW0u++GN97I/yZWqakpLF36PnPmvM6VKxdt6UopHm3Ukpd6PklQUN7XU8jQcOWymd3QoCH45jXmOH4cFi0ym2DYc3ODTp1Md0SlvI2nKGgSMAghhLA5fNis4lipkt1KggUsORnefhu+spvF6O5uuuWffDJ/59Vas379d7z77sscPbrf4Vir6vUY27UvdRvkb05mcrLZGbNqVRPY5Km+Z87AZ5+Z0Z32/TBKwX33weOPF9r0yLySgEEIIQRgVm/cvbtwgwUwZf/99/XHISGm+UF41QAAIABJREFUpSG/4xWOHz/EW289zaZNPzmkVw+swmvtetH+ni4oD888l6+1aQ3x9IT69c2YhTzbutU0sdhr2dIECmFh+Si48EjAIIQQgiNHYNeuwg8WwMwSnDTJbHvQujWMGeM4pi8v1qz5gjfeeIpr1y7b0sr5lOXFll3od98/KFMxMF/lp6SY6ZKVK5sNn/L9HnXpAsuWmVUZ77oLoqIKbQ+IgiIBgxBC3OKOHjULCAYFFU6wYJ1BYD8eIjTUjPGrVi1/ayskJycxY8YIPv98li3N3c2dyLvbMezeByh/Wy3Ix26SWsPVq6bu9eqZdRVy5do1sylU06amACsPD3j2WbOGQqNGea5fUZKAQQghbmHHjsFff5lgoSD3ZbDavNnMeIiKMjd7oaH5K/vo0f2MGvWow94PoYFVmPlAFI0aNEPnczfJ1FQTLAQFmbpmubeDK0lJsGKF2Rjq6lUzMOSNNxzzFOT61kVAAgYhhLhFHT9upvsXRrCQlgYffgjz5plf6TNnmutjg7yvi+Tghx8W89ZbT5OQcNWWdn+9u5nacwC+IWHofC4JefWqqXftOuBfMRdPTE0161ovXQoXr8/OYPt2M0CkYcN81as4FfK6Xc6UUg8qpXYopZKVUoeVUsNd5FFKqTFKqWNKqUSl1HqlVLiLfA2VUj8ppRKUUieVUq8rpdzzUpYQQtxKTpww17DCCBaOH4dBg+Djj81FF0yXhP30ybxKSkpk0qSnGTv2cVuw4OFehtc79+G9J0fjW7VGvvo40tLMdb5cObjjjlwEC2lpsHo1DBkCs2c7BgtVqsDw4Y5dEqVQkbYwKKXuBZYBc4GXgXuAqUqpDK31DLuso4BxwCvAXmA4sEYp1UhrfdpSlj+wBtgN9AJqAe9ggqCxuSlLCCFuJSdPmoWSCiNY+OEHmDzZdN1btWgBr79uzpcfhw/vYdSoRzl4cKctLaxiJd7r+zz16+b/d+C1a2brhho1zJ5OOYo70tPh11/NOtanTjkeCwqCxx4z6ykURn9PESvqVzAe2KC1HmR5vFopVREYr5SapbVOUUp5Yy7yk7XWMwGUUjFALDCU68HAEMAHeFhrfRn4USlVHohWSv1La305F2UJIcQt4dSpwgkW0tJg2jTTEm/l7g7PPQf9++d/Aajly+czZcpzJCUl2NJ6NmjGxH7DKOebvykW6emm9aN8Bah+ey73cpozx4xVsFehglnX+v77zRzMm0Suvy5KqZrAo8DtQOYVs7XW+qkbPD0ceD9T2mrMr/5WwC9Aa6A8YPvaaa2vKaW+A7px/SLfDVhlCRasPgOmAu2A73JRlhBC3PROnTIbPBV0sHDxIowcaZYWsAoNNWsr5HcCQGLiNaZO/SfLl8+3pXmV8WBcz4E82qZHnvd/uF6+mTJZvbp5X3JdXNeu1wOGsmXh4YehR498bihRMuXqK6OUehBz8XUDzgLJmbLobIrwBlIypVkfN8AEDPWBdGB/pnx7gMfsHtcHfnY4udZHlVIJlmPf5aIsIYS4qZ05A3/+CYGBBRssaG265617JQF07Ajjx+dzYSPgwIGdjB79KIcP77Gl1QyswrsDR1O/Wo18lW3dMMrPD2rXNks8Z+uvv8zuUvZNEDVqmKAhIAB69TIF3qRy+7V54//Zu+/wqKqtgcO/NekhBAhdmvSioghX7BKK2L0qCIIoiqIo6LWAV0FUrt2LfgoogsK1YMUONkQpioB0kF4SegmEkp5M9vfHnkkmBZJJJjMp632ePMA+Z072IOas2WfttYB5wEBjzKESfL+twD/yjZ3n+jXG9WstIMkY48x3XiIQKSKhxpgM13lHKSjRdcybaymlVKV18KD99B8T4+XWwGIQgX/9C+65x24QuPdeGDKkdLUVjDF88827vPLKCNLT03LG/3nupTzd936qlaJZFHg0jGpczOZaGzfaohGrV9va1TfckPf48OGlmk9F4W3A0AJ4pITBAsBkYLKI3A3MxAYL7l0S2Sd9lR+JyFBgKEDTpk0DPBullCqdgwdh2bKyCRbcOnaE0aPtzoLLLivdtZKTT/DCC/fy448f5YxFhITy1E3DuOm8HqW6trthVHi43d4ZWVSZhu3bbavpZctyx2bOhN69i/HiysfbgGEjUJr6mtOAs4G3gClACvAYMAFw71hIBKJEJCjfykAtIMVjRSARKKwpeC3XMW+ulcMYM8U1N7p06VLUIxallCq3Dh2yKwu1avkuWEhKspUh85cTuOaa0l9706ZVPP74zXmaRrWu34TXb3+M1g1K9wEuIwNSUuwOx4YNi6houWuX7SD5xx95xx0O2+8hK6tUc6movA0YRgH/JyJLjDHbvf1mrpv2cBF5EmgM7MDmGQAsdv26EQgCWgGbPF7eznUMj/PaefwZEWkCRHqcV9xrKaVUpZKQYD8Y16zpu0T9+Hh45BE4cgTef7/0lRrdjDF88cVkXh3/EBmZualxN3e9nDE33E1EqDfbFvJf2wY5QUHFaBi1f7/dHjl/fsEOkpdeCrfcUu46SPqTtwHD09gVhg0isgU4ku+4McYUuSBljEnEtQogIvcBi4wx7hv4IuA40Bd41nVOJHAtrk/+Lj8AI0WkujHGXQ6kH5CKTZ705lpKKVVpJCTA0qW+DRb++MM+dkhyFVZ85BG7Wl/alYukpGM8+5+7+GXuzJyxamHhjOtzP9d1Lt3zDc+GUaedVkSy5/vv237bznwpb+efb7tkNWtWqrlUBt4GDE7yflL3ioicD1wMrMJud7wF6O0aA8AYkyYiLwJPikgiucWWHNhHF26TgQeAL0XkJWx+xdPAq+6tll5cSymlKjxj7CrA+vU2WPCqnsAprvneezBpUm7VxrAwm/tX2mBh/fplPPFYH3bvi88Za3dac964bRTN6zUq1ZzdDaPatLFlEYoUFZU3WDj3XBsolPMOkv7kVcBgjOlWyu+XiV0FeBqb5LgQuMgYszbfeS9ib+qPY1c0lgG9jDEHPOaSKCI9gInYLZRHgddc1/bqWkopVdGlpdmOkwcP+m7rZFqardD488+5Y/Xrw/jxdnm/pIwxfPLheF6f+DhZztx8gAEXXskT1w8hLKTkyyLuVYVTNoxKTy8YTV19te0q2aiR7ZJVgXs+lBW/Vno0xiyn4LbKws4zwHOur1Odtx7o7otrKaVURZWQYKs3gr2h+8L+/faxwyaPNeVOneCll+yOi5I6fuwIzz5xC78uyY1CosIjee7m4Vx1zsWneOWpZWXZQCEszC4K1CysB0RKCnz7rQ0Mnn/e1lBwCwuzpSqLXRO66ilJpceGwCPYaoox2DyG37CPArQ3g1JK+YnTCVu3wrZtdtndV8UFV6ywlRsTE3PH+vSxAURpHkOsWzKX0WMGsCfxYM7YmY1b8X+3jaRZnYYluqY7UAgOhtNPtztCCtRVSE/PbTXt7oD10Uc2KcNT7dJsAgwwp9P+B0tOLrMtn95WemyDfYxQC/gDW4ipAfAgcJuIXGKMyV9VUSmllI8lJ9s6QsePQ926pe/V4LZzJwwblvs4PzgYRo2yFY9LKjsjg08n/ZvXP36dLI/dB7dfci0jrx1MWLD3UYjTmbv7oWlTuzBQYKuku9X055/njX7AtutMSan49RRSUnLfW+vW9lGKr5aZ8vF2heEl7K6DrsaYOPegiDTD9oR4CSjFPyullFJF2bvXVikODbXBgi81bWobLH70kb0Jv/wynFOKRpAntq9n3JO38tumlTlj0RHVeLH/g/Q663yvr5dtIOmEfWrQqNFJ+mI4nTB3Lnz6qS1G4alePbs9slu3IooxlGNOp93bmpZml5YuucQ+Xinj4MfbgCEWuNczWAAwxsSLyNPAmz6al1JKqXwyM21OQXy8XT0vq8qNDzxgdxoMHGgLHZWEpKex/pt3GfnmGPYn5VbxP7tpG/7vtpE0jvHuU7B750N2tv0AXb9+Ie/fGFiwwEY7+VtNx8TYSKhnz7L7iytrKSk2UHBv/2jf3v5F+CnnwtuAIRQ4cZJjJ1zHlVJK+dixYzaxMT3dt/eIjRvtbgLPgkbBwTZfoUSMIXjXNj6c8Bivzfsap8l9BDGk2w08cvUgQoKKf+sxxt4ns7KgTl1oUP8U20VFYM6cvMFCjRo2AeOKK3yzz9TfPFcTYmIgNtbWhChWtyzf8jZgWAWMEJEfjMn9VyC2v+h9ruNKKaV8xBibV7B+vW2E6Mu8vNmz4bnnoGtXu1WytHkQQSeOkrZwFg9Ne46FcbnFdGtVq85Lt/yL2A5FbpLLIzXVBkgxMbacc7HukQMH2taZ1arZJlHXXhuQm2upJSXZvuEOh11JaNfOPn8K4A4ObwOGccAsbKXHT4F92KTHvkBr4GrfTk8ppaqu9HRbW+HAAd+2pc7KgokTbaVGgIULYdo0uOuukl4wk4gta9k+93OGzHyLg0nHcg51ad6BVwc9SsOadYp9ubQ0+xVdA1q0OEnH6DVr4Mcf4aGH8j5iaN8eHnzQRkGl7a/tb1lZcPiwffYUE2P7hDdr5rvtL6XkbeGmH0XkGmyZ5dGAAAZYDlxjjPn5VK9XSilVPIcP+762AthHG088AUuW5I41bw6XX16y64Uc3EPU8nks2bCMu794m+T01Jxj9/bow4NXDCS4mMmFGRl2VSEyEtq2td0vC9i40UY6a9bYP595Jlx1Vd5zepSuq6XfuVcTgoNtwNO2rc3mLGf1ILyOV40xPwI/unoy1AISjTEpPp+ZUkpVQU6nrauwZYstPuTLD5fbttnchN27c8cuvdRWc/T2w7gjJYnItYsJj9/E97u38eCnE8l0VW2sERHFa4Me5ZJ25xbrWpmZdptoeDi0bGnTDgrcK7dtgxkz8raaBtv/4YorfLev1F/cqwkZGXbnRs+edotKOc6zKPEClytI0EBBKaV8JCXF1lY4etTeQ3x5D5w3D8aOtd/D7e677ZdX38fpJCxuI1Gr/wCHg4/iNjHm80lku9La6kfHMP3eccVqR+0uuhQSkhsoFJhLfLzd9fDnn3nHHQ57k7355ooVLJw4YZd5QkJszYQ2bezzpnK2mlCYIgMGERkLvGOM2ev6/akYY8x/fDM1pZSqOvbvt8FCaKgNFnwlOxveeQemePTnjYiAZ56xj8i9EXzkIFHL5xN89BBZMfWZsnAWL8/6X87x0+uexvR7nilyy6S76FJwsH1EX6tWISUR9u61raYXLMjtegX2xnrZZdC/f8VpNZ2ZaXc6ZGTY50uXXw5NmviulaifFGeF4WngR2AvBRs75WcADRiUUqqYyrq2wvvv5w0WGjWyOyJatSr+NSQ9jcgNy4jYvAZnVDQZ9Rrzyqz3mPrblznndGjUgmlDn6Z29cKaOFjZ2XZFAaBRY6hzskTO9HR49NHcXtpuF14IAwbYpfuK4Phx+xUSYnMt2rQpXSOOACsyYDDGOAr7vVJKqdI5ftwmNqamll39nRtvtL2Wdu2C886zPZcKbcxUGGMI3b2N6isXQmYmmfUakWUMYz+byOdL5uScdl7LM5k8ZAzVwwuvNJhtINlVdKlhQ7s78JSBUVgYXHMNfPKJ/fM//mEDhZYtiznxAMrMtN3AsrLsm73gAruaUFGLRXnwtpdEU2CfMSazkGPBwGnGmJ2+mpxSSlVGxtgb+Lp1NtmwTvF3HHotOtquKMyaBffdV/ytmUHHE6m2ciGhB3aTVasupmY46ZkZPPThf5mzdnHOeT3P7Mr/DRpZaEtqz6JLdevaoKhATt+xY7bO9cX5OlVefz3ExdmIpzS9tP3FvZoQGmprabdubZ+1VCLeJj3uAC4AlhZy7GzXeAUtzq2UUmUvPR3+/tvmLPiytgLYG/Pq1dC5c97xFi1suefiXSSTyM2rifz7L7LDIshs0ASAE2kp3DftORZvXZtz6o3/6MFzNw8vdNtkSkpuOYGGDQvZ7ZGUBF9/bdtNZ2TYvZ2NGuUer1bN7v8szzIzc+smVLLVhMJ4+0/1VAtmIUD2KY4rpVSVduSIfQTh7ofgK8bAb7/Bm2/aD+XPP1+yugoh+3dRffk8HKnJZNZpAK4SzoeTjnHXlGdYt3trzrlDuv2TUdcMxpFvh0JGhg0WatSweRIF+iGlpMB339lgwZ3QADbB8dFHvZ90IHjmJpx1VoXPTSiu4uySqAl4/k00EpEW+U6LAG4H9vtwbkopVSlkZ+fWVqhRw7e1FZYtgwkT7KqF2zPP2Mf9xX3k70g+QbW1iwnbuRlnjdpk1sv9pL838RB3vD2W7Qf35Iw9evXtDO1+I+KRdOHe+RAWdpKiS+np8P33MHOm3VroqVmzgo8kyhvPugn168P559vky0q6mlCY4qwwPAg8hd0BYYCZJzlPXOcppZRySUmxj+gPH7bP8X1VMmDDBpg0CRYvzjterRoMGmRXxovkdBK+fT3V1vwJwUFk1m+SJ/Ny64Fd3PH2U+w/mgCAQxyM6zOMfhf0znOZ5GQbMDRqDHXr5NsimZkJP/0En38OiYl5v3+jRrbV9MUXl99aCvnrJrRrVyVWEwpTnIDhayAOGxBMw5aF3pbvnHRgvTFmjU9np5RSFdiBAzanIDjYd48g4uNh8mTblNFTaKitYTR4cPF2QQQfPkDUivkEH00gs3YDCM77SXl1/GbufucZEpPtakBIUDCv3vooV5x9Yc457scPtWrZjpcFEhoPH4aRI+2uAU/16tlAoVu3QgowlAP5qzBW0LoJvlacbZWrgdUAImKAWcaYw2U9MaWUqqiysmDzZtixw34Y9dV95tVX4dNP7ad5N4fD7kAcOhQaNCj6GpKWamsqbFmDM6qmXVXI54/Nq7hv2vOkZKQBUC0sgjfveIIL25wN2G2SSSdsINSq1UlKOYN98zExuQFDTAz062crNJbHpfykJLuaEBQEZ5xhn634sj1oBedt86n3ymoiSilVGaSmwooV9t7j69oKERF5g4XYWLtVsnnzYrw4O5uwXVuJWvU7OJ1k1mtc6GOAH1b/wSMfjs/pC1GrWnXeuftpOjZtDdj3l5FhNwXUr++xQJCdbR85eN5gReDWW+2+zj59bM+H8tYrwXM1oW5d27iqnPd0CBSvN/SIyBnAXUBbIH/qjjHGVLA2YUop5RvHjtkkRIej9LUVnM6Cq/WDBtmcwdatYfhwWzywOIKOHSFq5UJCDu0hq2ZdTFjhWZef/PkjY2e+hXGVYm5Qsw7T73mGVvWbkJVlg6Dq1aFlK4iMcL3IGPjrL9sYKjPTZmB6Tvzss21t6vJ2A3Z3iMy/mlABejoEireFm7oC87E5Da2BNdiOlU2B3cDWk75YKaUqsQMH7MpC9ep2JaCksrLsjsPp0+Htt21ugFtUlO3DVK9e8e5rkplBxObVRK7/i+yIaoU+fgAwxvD23JmM//6DnLEW9Rox/Z5xNKxZl+PHbRDUvLl9qiCCDRRWrbKBwubNuRdbsMAufeRMQspPsOB02tWE9HQbHPToYXdolJf5lXPerjA8D3wJDAIygSHGmBUi0h34AJsQqZRSVYYxNhHx779Ll6+QnQ0//2wTGt3tpydPhmfz/VQtVvKkMYTs30n1FfNxpKWSWee0kyYXZmdn89J305k2/5ucsbOatGLq3U9RLbgGx47ZAKVhQ4+0g3Xr4MMPYf36vBcLC7P1Ccqb5GT7uMThyN3pUKeOriZ4yduAoSO23oK7dVgQgDHmVxF5FngB6Oq76SmlVPnldNrGUTt22MffJUn4NwYWLbJbJD0/qAOsXGlXzqOiin89R9Jxqq39k/BdW8msWQdn9ZOXJ85yOnniswl89devOWMXtO7IhNuewGRG4giF9u3tVk0ANm60KwqrV+e9UEgIXHkl3HRT+SmH7HTaSllpaTaS697drib4sghGFeNtwBAKJBtjskXkCNDQ49gmoJhP1JRSqmLLyIA1a+DQoZInN65ebQOFFSvyjkdHwx13QN++XtzfsrII376OamuXQHAIGflqKuSXlpHOvz54hbl/51b673XW+fznn4/iyA6lcTO7au9wYOtYT5liEzQ8BQdDr152P2d52U2QkmJXE0TsSkL79jaa09WEUvM2YNgKuEuArQHuFJFZrj/fgVZ6VEpVASkp9iafmmqX6721dast47xgQd7x8HDblHHQoEIqJZ5CyME9RC2fhyP5OFkx9QvUVMjvRGoy9057jqXb1uWM3dSlF4/0vo96dYJo1Cjfo5XQUBsduTkc9hN7v36+rXFdUtnZdjUhNdUWoejWDU4/XVcTfMzbgOE7oBvwETafYTZwHHACUUBx25sopVSFdPSo/aAdFFSygn/GwH/+k7eUc1CQbco4ZIh3uyscySeIXLeE8PhNOKNjyKrXuMjXJJxI5M4pT7Nhz46csdsvuolHrrqNZs2E6GjXJD1bB8XEwFVX2T7Zl15qiy6ddlrxJ1pW0tJsoGCM3Tpyxhll1ydceV2H4WmP3/8iIucDNwGRwI/GmJ99Oz2llCo/fLETQgSGDbPbIkVsaYJ77sm7G6JIWVmE79hgSzoHOQqUdD6Z3UcOMHjyWOIT9uWMPdBzMI/ffKPNwTi0H6Z/apMW7ror74v79LEFl5o29WKiZSA720Ztyck2uePCC207zpxEC1VWStVY1RizEljpo7kopVS5ZIxNbNywwT6qL26RwqQk2725X7+8CZFdu9rVhB49bKNDbwQn7KP68vkEHU8ks3bRjx/cNu+L5863n+LA8SOA7QvxbN/7eahPL8KTEuDtz2y9aafT5iZcf7199u8WHW2/AiU93W6JzM62jxs6drRbN8prD4pKyIed2JVSqvJxOu3mgLi44u+ESE+3vZamT7fFnGrWtCv6bu5VBm84UpOJXLeEiB0byKpek8z6xV+SWBm3kbunjuNYahIAocEhvHv/SAZ2aot8MBV+/NEWXXLLyrJbN66/3rtJ+pox9i/wxAm7pHPeebYWtTcJHspnitPeege52yiLYowxxWyoqpRS5Zt7J0RCQvEejRsDs2fDW2/ZxxdukyfbzQQlap/gdBIWv4mo1X+ACBknKel8Mgs3ruD+/71AakY6AFHhEXz1r4fouXkjvPVf+yY9dehgyzkXt4xkWcjMtKsJmZn2Wc1ll9mciWD9jBtIxfnbn0/xAwallKoUUlJscmN6et6V+ZPZu9cWWVq6NO/4aafZHIWSrJx7dpTMiqmPCfGuKtTslQsZ+dFrOX0h6kRV58cLzqfz66/ZHQWe2rSBgQPhnHMClzR4/Lj9Cg2182jdunitN5VfFKdb5WA/zEMppcqNo0dte4SQkKJ3QmRn28cPEyfmvQfHxNg8hRtu8L76o6SlUO3vvwjftg5nVI2TlnQ+lRl/fM8zX76d0xeiad26/HzddbR99928JzZvblcUunQJTKDg2fypfn244ALbSro8drOs4nR9RymlPOzbZyss1qhR9Db++Hi7RXLVqtwxh8PWUhg6FCIjvfzm2dmE7dxC1KqFYLJP2lHyVIwxvP7Dp0z65aOcsfZNmvDzM8/QOCYGfvnFTrxJE7uicP75gUkcdJdrDg7OLddcXoo/qUJ523zqtqLOMca8X/LpKKVUYBgD27fbBMfi7oR47728wUKLFjB2bMke/wcnHiJqxQKCD++3jx9CvWuIlJEB+w4fZ/JvM5i57Iec8fPatOH7sWOp7d7hMGSITSS8+OKS1bIujcLKNZ9+ujZ/qiC8XWH430nGPXMcNGBQSlUoTqfdMrlzp3c9IR54AH7/3d5/Bw+292LvHz+kErlxORGbV5NdLZrMBsWvc5CRYe+9m/bt4Ivls5i9ej5pmblJjL0iIvhy3DiiPJc6zjnHuwn6gme55rZt7YqClmuucLwNGJoXMlYbuAYYANxa6hkppZQfZWTYVYIjR07dNjoz0z5u9yzYVLMmjBtn+y21bevlN87OJmzXVqJW/Q5ZWWTWawSOoiMVd5CQmZXFH9sX8/lfs1i8ZX2B8/oCH6SmErZmjX3s4G/Z2TZISEmxz3cuucQuwZSm97cKKG8rPcYXMhwPrBARAR7GBg5KKVXuJSfD8uX2JnyqnRDr19vAoGNHeOKJvMdKci8OOnqYqJULCEnYR1bNupiwUydLZGbahEpjIMV5lNnrfuKDBT+y98jhAud2Ah4EBkVH4+jbFzp18n6CpZGebvehGmNrJpx5pk1m1AJLFZ4vkx4XYgMGpZQq9xIT7bbJ0NCTd2ROS4OpU+GDD+wH5q1b4fLL7YaCkpD0NCI2rSRy00qyI6JOufvBM0gIC4ODGVv44PdZzFy0kIysrDznBgN9gOHAhVFRyE03wdVX+6/5kmeBpWrVbBTVqpV3fblVuefLgOF8IMmH11NKqTKxd69tLR0dffJ76qpVdlVh587csbAw2+nZa8YQuns7UasWIJmZZNY9rdDHD1lZNkjIzrbfq079TOb+/Qdv/zyLJZs3Fzi/HnAvcA9wWmQk/POfcN11JdieUUKFFVhq1Mj/yZTKL7zdJTG2kOFQ4EzgamCiLyallFJlwXMnRJ06hRcOTEmBSZPgs89cTRtdunSBMWO8bBIFBB07QrVVvxN6YBdZtepiauR9hu8ZJISG2vYIyc7DvDfvJ6b89CMHjh4tcM2uzZszIi6OPsYQFhEB115rgwV/faJPSrLFKkJC7HOatm21wFIV4O0Kw9OFjKVj8xieA14o7YSUUqosOJ02F2HnTpvcWNgj9SVL4Lnn7AqEW7Vq8OCD9n7szWN4ycwgYtMqIjcsJzs8Is/uB88gISQEGjSA6BqGVXEbee6DWXyxaBFZTmee64UGB9PvkksYcc01/KN1a3j3XTuhm27yT1Mozy2RdevaWtdNm3q/LURVWN4mPWrWilKqwklPt48YEhML7wmRng6vvAJff513/MILbZJjgwZefDNjCN0XT9Ty+Tgy0sis0xCCgnA67eqFO0ioX99+KJegDD5ZuICJs2ezYtu2Apc7DRjWoQND//1v6nl+ih8yxIt8Nq2WAAAgAElEQVRJlUJqqv2LMwbat7dbIuvU0S2RVZDfKz2KSH9gFNAGOAbMBf5tjNnrcU5N4FXgn9hHHguBEcaYrR7nDAamF/IthhljJnucJ8DjwDCgDvAX8IAxZlUhr1VKVTJJSXYnRFbWyXdChITA7t25f46OhkcesR0mvbkvBh1PpNrqPwjbG0dmrbpkRMWQmprbMbpePRskREbCroRD/Ofz75n6888cPnGiwLUuBkYANwAhhQQSZcoYGyQkJ9u/jIsv1i2RyvuAQUSCgNuAC4BGwB5gEfCBMcZZxGuvAz4GJgEjgYbAs8BsEelsjMl2nfopNi/iQWxQMQaYKyJnGWOO57tsd8Czi8r2fMf/DTzp+n4bsTs5fhGRM40xJUlfUkpVEEeO2GAhNPTUj9gdDpuf0L+/bWXw2GP2Q3RxSUY6EZtXE7lhOc6wCBKjm5GVDsFOe51atdx5iIb569YxYdYsvl6yhOzs7DzXCRdhoDEMB3LKKzkctoaB8UMPwIwMm8TodNoKjB072qQK3RKp8D7psRnwE3Z1YDdwADgLuAt4TESuOEmtBrcBwApjzHCPax4HvgHaAhtE5ALgcqCnMWau65wlwA5gKPDffNf8yxhT6O4MEQnHBgwvGGMmusb+BOKwO5DGFP/dK6Uqkj177E6I/D0hjIE//rA7/zyTHhs3hk8+8TKp0RhC9+wgauUCTGo6h8MaYiSIOrWgdowNEhwOSE5LY+pP85gwezbr4gv+iGwWFsZ96ekMMYacbgoidtdB//625WVZcneJDAuDzp1tl0h/5EWoCsXbFYaJQDRwsTFmkXtQRC4CPgcmANed4vUh2BUDT+4UYPfC3zlAJjDPfYIx5oCIrMbuxMgfMJzKha75fuZxrWQR+Q64Eg0YlKqU3DkL+XdCJCTASy/Bb7/BvffCXXflfZ03wYJ790Pw3l0cDakD1WrToL7tQ+FujbB9/37e/P573p0zh6PJyQWu0b1uXUYkJHBtejp5NiJedBHccotNKiwrnl0iGza0CRtNmhS+dUQpvA8YugP3eQYLAMaYP0TkCYreVjkN+NrVxOproAH2kcSvxhh3bdNwwFnI440MoH0h19wmIrWBbcCrxpi3PY61A5zAlnyv2QD0K2KuSqkKyum0j9vd9z5jYPZsGD/e1hYCu8kgNhZatvTu2u7iSyF/ryLFRGBqNaVRQ9tLKTjYdov8eeUqJsyaxexly3LaS7tFhoVxW2wsw6+5hjM2brR9sd26drWtLpsXVoXfR1JS7LOaoCA44wztEqmKzduAIQk4eJJjB4GUU73YGDPblaz4LvCea3gReVcltgLhrnyFtQAiEoHNaajucd4+bG7CUiAI6A9MFpFIY8xrrnNqAUmFBB+JQKSIhBpjMjwPiMhQ7KMPmpZldK+U8ov9++H552HRorzj11xjdyoUW3Y2oXu2E7J4IZnJmTjrN6RJ0yCio3PrFB1ITOS6555jaSFFllrWq8f9117LHT16UNNdL+G002DmTFvsaMAA+yigLHj2dahZU7tEqhLxNmD4EFtY7IdCjt1DEZ0qRSQWmAy87rpGfWxth69EpKfrxv4TNl/hbRG5AzgOvAjUAHLqoRpjfnKd6/aDK2dhjIi87pFA6RVjzBRgCkCXLl38kGWklCoL2dn2XvzGG/Y+6daoEYweDeedV/xrBSUm4Fj0O479e5BG9Wh2VjjVq+fdQbE/MZHuY8awYdeuPK+94qyzGBERwRUrVuBo0SJvcaXgYHj11bIruOTZ16FNG7uiUNi+UqWKwduAYSvQV0TWAl9gkx7rY8uYV8fetO90n2yMmZbv9eOBb40xj7kHRGQVdvfC9cCXxpgM19bLj13jAL9jg5HuRcxvJnAzcDp2t0QiECUiQflWGWoBKflXF5RSlUNaGjz6qE16dBOx+YP33Vf83YEmNY3sZcsJ27yKqAbVqXtpU6pVK3jeviNH6D5mDBtdezMdDgf39+zJ8Oxs2ixYYPMEAGbMgBdeyHvD9nWwUFhfh9atKXTiSnnB24BhkuvXxsAZhRx/0+P3Bpuz4KkdNhDIPcmYTSKSCrT0GFsqIq2wuzGyjDHbRGQWsLiI+Zl8v27EPq5oBWzKN4+NKKUqnWPHYN26vLsQmzWDJ5+Ec845+es8ZWVkk7FhKzXX/k7tGlnEXNaI8MjC+yPsO3KE2NGj2bRnDwBBDgczzjuPfgsX2qJHnjIzbW2DslhRyMqyqwlZWXYZRfs6KB/zNmAobSZOPHCu54CItAcisFsdcxibKbTJdU5roCdwbRHX7wMkuL4P2PyI49jW8M+6rhXpus6Ukr8NpVR5Vb26fTSflmbvlYMGwd13F+9xfUYGpMQfovb6BbQIOkDNf9QlpPrJOz7uPXyY2DFj2OwOFkT4OCSEvovzfbZp3hwGDoR//MP3jwOSk21+QnAwnHWW7etwsvabSpWCt6WhT1VjoTgmA6+JyF5ycxjGYoOF790niciT2BWABGydhyeBT4wxczzO+QKb8LgGu4rQz/X1gDt/wRiTJiIvAk+KSCK5hZsc2C2gSqlKxuGw9+fMTHjqKVvNuCipqZB8KIW6O5fTNmktNVpEEVT75K2nAfYcPkzs6NFscTWeCAY+MYab0tNzT2ra1CYznn++b4sfeSYxxsRoEqPyixJtuBWRM4HLgBjgCDDPGPN3MV76BnZ75DBs8uRRbH7C48YYz03KtYH/w5Zy3oWtvTA+37U2AXcCTbA1HNYDtxljPsh33ovYAOFx13WXAb2MMQeK9WaVUhVOVJRNFSiqgWJSEiSfyKb+8c2cte93qlfLxnF6oyJv7rsTEogdPZqt+/YBECzCp8Zwo/uE006zdRQuvti3jwQyMuxjh+xsaNXKrihoEqPyE8m/R/iUJ4sEA/8DbiG30BLYnIGPgMFFlYeuSLp06WKWLVsW6GkopbzQrVs3UlLgxRfnUaNGwePunMCMDGjoOEDLPQuISj2I1K9frE/ouw4eJHbMGLbtt5Xlg4OC+OzOO7nhnXfszbtfP+jWzbeBwvHjdtLh4dCpk01i9Fcra1XliMhyY0yX/OPerjA8hd2FMBa7xXI/tvjSra5j212/KqVUueJ0wtGj9tcmMck0P7KMalvX2drRdYtRcyUzk51ffUXsJ5+wPcvu8A4JDubzUaO4/vzz7eOHM87wXaVEp9OuJmglRlVOePsv71bgWWPMcx5j8cBzrqZUd6ABg1KqHMnMtB/ORaBZYyfN0jYRvmKRXWpo3Ljo3IKsLPj1V+I//pjYw4fZ4RoOCQ5m5mOPcV3Xrnbg7LN9M+HUVFuJUcS2km7f3rtOWEqVEW8DhtOwOw8KswgYXbrpKKWUb2RkwMGDtnV1u3bQUPYT+ud82z+hXr2iHz84nbBgAXz8MXH79xNL7lauUOCLxx7jGnewUFrG2OWPpCTb9OmSS2zN6vCT79BQyt+8DRj2AhcBvxRy7ELXcaWUCigRW5ypTRuoVy2ZoGVLYMMGmwXZ5NS7H8jOht9/h48/hj17iAO6kbtXO9Th4MtRo7jaF8FCZqZ97OBuJ92jh7aTVuWWtwHDDGC0iGS7fr8Pm8PQH7u68JJvp6eUUt4LD4dLLnTi2LgevvvT3oCLevxgDPz5pw0UXC2od2CDhZ2uU8JCQvjqiSe4snPn0k0wKcmuKISE2CTGNm0oNENTqXLE24DhaaAF8Izr926CreA4ziezUkqpUpCsTBxffG5zAerXh9DQol+UkQGTJ9sbOTaDO1aEna6dZGEhIXz9xBNcUdJgwem080lLszkJvXrZEpQhISW7nlJ+5m3AUBNb++A54FJy6zAsKGYdBqWUKnspKbaXQlGPHzyFhUGfPvDOO2wLDSU2KIhdrtLOYSEhfDN6NL3PPbeIixQiPd3mTbgbQJ15JtStq7UTVIVTZMDg2v3wJPAgEA04ge+AIcaYo2U7PaWUKqFTfXJfvRrWroVbb807fsUVbN23j9g//2T3kSMAhIeG8s3o0VzeqVPxv3f+BlAXXGCTGLUBlKrAirPCcC+27sI84C/sI4kbsD0a7iizmSmllK/9/Td89JENFgC6drVFkFy2JCQQu3gxezyChe/GjKFnsbtWZdnVhIwMbQClKp3iBAx3A1ONMfe4B0TkHmCiiNyjLaKVUuXepk22tfSqVXnHP/nEtrEENu/ZQ+zo0ex1BQsRoaF89+ST9ChOfYWUFNvbweGwxZvat7c9HpSqRIoTMLQAHs039inwFtAM2OLrSSmllE9s22YDhfwl3oOC7BbGfv0A2LR7N7FjxrDPI1iYPXYssR07nvza+WsnXHoptGihtRNUpVWcgCEK+/jB0wnXr9V9Ox2llPKB1FSYMAFWrMg77nDYPg/9+0ODBgBs3L2b7h7BQmRYGLPHjqXbWWcVfu2sLFs7ISvL7nLQ2gmqiijuLolGItLC489BHuN5Eh+NMdt9MjOllCqJ1FT7CMKTiF0B6N/f5hS4bNi1i+5jxrA/MRGAauHhfD92LJeeeWbB66ak2G2RISG2S2S7dkW3w1SqEiluwDDzJONfFzKm2T1KqcCJiIDq1e0OBbBNmwYMsM2hPKzfuZPuY8ZwwFV3oVp4OD889RSXnHFG7knZ2TY3ISXFBgfdu9uKjMXoaqlUZVOcgEF3QiilKpaGDe02xsGDbV5BPn/v3En30aM5eOwYAFEREfzw1FNc3KGDPcGzZHOLFnZFoUEDfeygqrQiAwZjzHv+mIhSSvlMZCQ8+CDUrl3g0Lr4eLqPGcMhj2Dhx6ee4qIOHXJLNoeG2pLNbdvahEallNeVHpVSqsJaGxdH9zFjSDhu87iru4KFC+vVg127bIChJZuVKpQGDEqpKmHNjh30ePLJnGAhOjKSn0aM4PzoaJvfcMYZtu+ElmxWqlAaMCilKr3VO3bQY8wYDrsSIaPDw/n5oYfoeuONtr9DVFSAZ6hU+acBg1KqUlu1fTs9xozhSFISADUiI/n5gw8479pr9bGDUl7QgEEpVWmt2LCBns88Q2JKCgA1a9Rgzpw5dPnHPwI8M6UqHg0YlFKV0vLVq+k1dWpOsFCrVi3mzJlD586dAzwzpSomDRiUUpXOCaeTntOmcTQ5GbDBwi+//MK5554b4JkpVXFpwKCU8qkjR44wd+5c0tLSAvL99+7dy/bt23E6nQDExMTwyy+/0KlTp4DMR6nKQgMGpZTPfPHFFwwdOpQjrkZOgRYTE8PcuXM555xzAj0VpSo8rXOqlCq1pKQkhgwZQp8+fcpNsFC7dm1+/fVXDRaU8hFdYVBKlcrSpUsZOHAgW7duzRlr0qQJl112WUDm8/PPPxMaGsr8+fNpUUgfCaVUyWjAoJQqEafTyQsvvMDTTz+dky8A0L9/f9566y1qBqj1c7du3QA0WFDKxzRgUEp5LS4ujltvvZU//vgjZ6x69eq8+eabDBw4ENHyykpVOprDoJTyyowZMzj77LPzBAsXXXQRq1ev5tZbb9VgQalKSgMGpVSxHD16lAEDBnDrrbdy3NXAKSgoiHHjxjFv3jyaN28e4BkqpcqSPpJQShVpwYIFDBo0iJ07d+aMtWzZkhkzZtC1a9cAzkwp5S+6wqCUOqnMzExGjx5Nt27d8gQLd955JytXrtRgQakqRFcYlFKF2rx5MwMHDmTZsmU5Y7Vq1WLq1KncdNNNAZyZUioQdIVBKZWHMYapU6fSqVOnPMFC9+7dWbNmjQYLSlVRGjAopXIkJCRw4403MnToUFJcXR5DQkL473//y5w5c2jcuHGAZ6iUChR9JKGUAmyFxMGDB7Nv376csfbt2/PRRx9peWWllK4wKFXVpaWl8dBDD9G7d+88wcL999/PsmXLNFhQSgG6wqBUlbZu3ToGDBjA2rVrc8bq1avHtGnTuPrqqwM4M6VUeaMrDEpVQcYY3njjDbp06ZInWLjqqqtYs2aNBgtKqQJ0hUGpKmb//v3ccccd/Pjjjzlj4eHhjB8/nmHDhmlpZ6VUoTRgUKoK+fbbbxkyZAgJCQk5Y+eccw4zZsygQ4cOAZyZUqq800cSSlUBycnJ3HvvvVx//fV5goVHH32UxYsXa7CglCqS3wMGEekvIitEJElE9ojI+yJyWr5zaorINBE54jrvBxFpVci1OojIXBFJEZG9IjJORILynSMi8oSI7BKRVBFZICKa9q2qjOXLl9O5c2fefvvtnLFGjRoxd+5cXnnlFcLCwgI4O6VUReHXgEFErgM+BhYB1wOPAZcCs0XEcy6fAr2BB4EBQG1grohEe1yrFvALYFzXGgc8AjyT79v+G3gSeAm4FkgCfhGRBr5+f0qVJ06nk5deeonzzz+fTZs25Yz36dOHNWvW0L179wDOTilV0fg7h2EAsMIYM9w9ICLHgW+AtsAGEbkAuBzoaYyZ6zpnCbADGAr81/XSe4EI4EZjzHFgjiugeFpEXjbGHBeRcGzA8IIxZqLrWn8CccBwYExZv2GlAmHXrl0MGjSI+fPn54xFRUUxYcIEbr/9dk1sVEp5zd+PJEKAY/nGjrp+df8EOwfIBOa5TzDGHABWA557va4EfnIFC26fYIOIy1x/vhCIBj7zuFYy8J3r9UpVKikpKUyaNImOHTvmCRa6du3KqlWrGDx4sAYLSqkS8XfAMA24RERuE5FoEWkDPAv8aoxZ7zonHHAaY5z5XpsBtPf4cztgo+cJxpidQIrrmPscJ7Al37U2eJyjVIV3+PBhxo0bR7NmzRg+fDhHj9o43OFwMHbsWBYuXEjLli0DPEulVEXm10cSxpjZIjIYeBd4zzW8CLjO47StQLiInGWMWQsgIhHAmUB1j/Nqkbs64SnRdcx9TlIhwUciECkiocaYjFK8JaUCKi4ujldffZV33303p1mU2+mnn86HH37IRRddFKDZKaUqE38nPcYCk4HXgVigPxADfOWxu+EnbL7C2yLSVkQaul5TA8j2wxyHisgyEVl26NChsv52SpXIypUrGTBgAK1atWLChAl5goWmTZvy+uuvs27dOg0WlFI+4++kx/HAt8aYx9wDIrIK+2jheuBLY0yGiPTH7qZwP3L4HXgf8EzrTsQGEfnVch1znxMlIkH5VhlqASmFrS4YY6YAUwC6dOlivH+LSpUNYwxz587l5ZdfZs6cOQWOn3322YwaNYq+ffsSEhISgBkqpSozfwcM7bCBQA5jzCYRSQVaeowtddVdaANkGWO2icgsYLHHSzeSLw9BRJoAkeQGGhuBIKAVsMnj1AL5D0qVV1lZWcycOZOXX36ZlStXFjjeo0cPRo0aRa9evTShUSlVZvwdMMQD53oOiEh77M6GOM9xY4zBdZMXkdZAT2wdBbcfgJEiUt0Yc8I11g9IBdzp4YuA40BfbHIlIhLpus4UX70ppcpCcnIy06dPZ/z48cTFxeU55nA46Nu3LyNHjqRz586BmaBSqkrxd8AwGXhNRPZib/j1gbHYYOF790ki8iR2BSABOAtbeOkTY8ycfNd6APhSRF4CWgBPA6+6t1oaY9JE5EXgSRFJdF3zYWzuxoSye5tKlVxCQgITJ05k4sSJHD58OM+xiIgI7rzzTh5++GFatGgRoBkqpaoifwcMb2C3Rw7DFl46is1PeNxVH8GtNvB/QB1gF7ZY03jPCxljEkWkBzARW1fhKPAaNmjw9CI2QHjcdd1lQC9XbQelyo3t27fz6quvMm3aNFJTU/Mcq127NsOHD+f++++nbt26AZqhUqoq8/e2SgO85fo61Xn/Av5VjOutJ28i5Mm+53OuL6XKneXLl/PKK6/w+eefk52ddyPQ6aefziOPPMIdd9xBtWrVAjRDpZTS9tZKBYQxhjlz5vDyyy8zd+7cAsc7derEqFGj6NOnD8HB+r+pUirw9CeRUn6UlZXFZ599xssvv8zq1asLHO/VqxejRo2iR48euuNBKVWuaMCglB8kJyfz7rvv8uqrrxIfH5/nWFBQEDfffDMjR46kU6dOAZqhUkqdmgYMSpWhgwcPMnHiRCZNmsSRI0fyHIuIiOCuu+7ioYceonnz5gGaoVJKFY8GDEqVgW3btjF+/HimT59OWlpanmN16tRhxIgR3HfffdSpUydAM1RKKe9owKCUjxhj+O2335gwYQLffvttgR0PLVq04JFHHmHw4MFERkYGaJZKKVUyGjAoVUpJSUl8+OGHTJgwgfXr1xc43rlzZ0aNGsWNN96oOx6UUhWW/vRSqoS2bt3KpEmTmD59OseOHStwvHfv3owaNYrY2Fjd8aCUqvA0YFDKC9nZ2fz8889MmDCBH374AVsXLFdUVBS33347w4cPp127die5ilJKVTwaMChVDMePH+d///sfEydOZMuWLQWOt27dmuHDhzN48GCio6MDMEOllCpbGjAodQobN25k4sSJvPfeeyQlJeU5JiJceeWVjBgxgssvvxyHwxGgWSqlVNnTgEGpfJxOJ99//z0TJkxgzpw5BY5HR0dz5513cv/999OqVasAzFAppfxPAwalXBITE3n33Xd588032bFjR4HjHTp0YPjw4QwaNIioqKgAzFAppQJHAwZV5a1du5YJEybw4YcfFmgr7XA4uO666xgxYoTudlBKVWkaMKgqKSsri2+++YYJEyYwf/78AsdjYmK46667GDZsGKeffrr/J6iUUuWMBgyqSklISGDq1Km89dZb7Nq1q8Dxs88+mxEjRnDLLbdoNUallPKgAYOqElasWMGECRP4+OOPSU9Pz3MsKCiIG2+8kREjRnDxxRfrYwellCqEBgyq0srIyOCLL75g4sSJLFq0qMDxunXrMnToUO69914aN24cgBkqpVTFoQGDqnRSUlIYP348b731Fvv27StwvEuXLowYMYKbb76Z8PDwAMxQKaUqHg0YVKVy4sQJrrzySv7444884yEhIfTt25cRI0bQtWtXfeyglFJe0oBBVRrHjx/nyiuvzPP4oUGDBgwbNoyhQ4fSoEGDAM5OKaUqNg0YVKVw7NgxrrjiChYvXpwz9sILL/Dwww8TGhoawJkppVTloAGDqvCOHTtG7969WbJkSc7YG2+8wYgRIwI4K6WUqlw0YFAV2tGjR+nduzdLly7NGZs4cSL3339/AGellFKVjwYMqsJKTEzk8ssvZ9myZTljkyZN4r777gvgrJRSqnLSgEFVSImJifTq1Yvly5fnjL355psMGzYsgLNSSqnKSwMGVeEcOXKEXr16sWLFipyxyZMnc8899wRwVkopVblpwKAqlCNHjtCzZ09WrlyZMzZlyhTuvvvuAM5KKaUqPw0YVIVx+PBhevbsyapVqwAQEaZOncqQIUMCPDOllKr8NGBQFUJCQgI9e/Zk9erVgA0W3nnnHe68884Az0wppaoGDRhUuXfo0CF69OjB2rVrARssTJs2jcGDBwd2YkopVYVowKDKtYMHD9KjRw/WrVsH2GBh+vTp3H777QGemVJKVS0aMKhy6+DBg3Tv3p2///4bsMHCe++9x6BBgwI8M6WUqnocgZ6AUoU5cOAAsbGxOcGCw+Hg/fff12BBKaUCRFcYVLmzf/9+unfvzoYNGwAbLHzwwQcMGDAgwDNTSqmqSwMGVa7s27eP7t27s3HjRsAGCzNmzKB///4BnplSSlVtGjCocmPfvn3ExsayadMmAIKCgpgxYwb9+vUL8MyUUkppwKDKhb179xIbG8vmzZsBGyx89NFH3HzzzQGemVJKKdCAQZUDe/bsITY2li1btgA2WPjkk0/o06dPgGemlFLKTXdJqIDavXs33bp1ywkWgoOD+fTTTzVYUEqpckZXGFTA7Nq1i9jYWLZt2wbYYOGzzz7jhhtuCPDMlFJK5acBgwqInTt3Ehsby/bt2wEICQnh888/5/rrrw/wzJRSShVGAwbld/Hx8cTGxrJjxw7ABgszZ87kuuuuC/DMlFJKnYwGDMqv4uLiiI2NJS4uDrDBwhdffMG1114b2IkppZQ6Jb8nPYpIfxFZISJJIrJHRN4XkdPyndNQRKa7jieJyEoRGZjvnMEiYgr5ujffeSIiT4jILhFJFZEFInKOP96ryisuLo5u3brlBAuhoaF8+eWXGiwopVQF4NcVBhG5DvgYmASMBBoCzwKzRaSzMSZbRBzAt0BtYBSwH+gDfCgiqcaYL/NdtjuQ6vHn7fmO/xt40vX9NgIPA7+IyJnGmP0+fYPqpHbs2EG3bt3YuXMnYIOFr776iquuuirAM1NKKVUc/n4kMQBYYYwZ7h4QkePAN0BbYAPQBugCXGeM+c512lwR6Qr0A/IHDH8ZY5IK+2YiEo4NGF4wxkx0jf0JxAHDgTE+el/qFLZv3063bt3YtWsXAGFhYXz99ddcccUVAZ6ZUkqp4vL3I4kQ4Fi+saOuX8XjHE5ynuCdC4Fo4DP3gDEmGfgOuNLLa6kS2LZtW4Fg4ZtvvtFgQSmlKhh/BwzTgEtE5DYRiRaRNthHEr8aY9a7zlkHLAHGiUhr13mDgYuAyYVcc5uIZInIJhG5J9+xdoAT2JJvfIPrmCpDW7duzRMshIeH8+2339K7d+8Az0wppZS3/BowGGNmA4OBKdgVhE1AEHCTxzkG++nfAWx2nTcFuNMY86vH5fZhcxMGAdcCi4HJIvKQxzm1gCRjjDPfVBKBSBEJzT9HERkqIstEZNmhQ4dK8W6rti1bttCtWzd2794N2GDhu+++4/LLLw/wzJRSSpWEv5MeY7GrBK8DPwD1gaeBr0SkpzHG6Up6fB+b9NgPOAhcBbwrIoeNMT8CGGN+An7yuPwPrpyFMSLyujEmuyRzNMZMwQYodOnSxZTkGlWRMYZt27axZMkSli5dyqeffsqBAwcAiIiI4LvvvqNHjx4BnqVSSqmS8nfS43jgW2PMY+4BEVmF3b1wPTah8RrXVxtjjPtRwjwRaQK8DPx4iuvPBG4GTsfulkgEokQkKN8qQy0gxRiT4ZN3VQUdOnSIpUuXsnTp0pwgITExscB5ET3DbnwAAAxBSURBVBERzJo1i+7duwdglkoppXzF3wFDO+y2yhzGmE0ikgq09DgnxSNYcFsJFFUK0OT7dSP2kUcr7OMPz3ls9G7qVVdqaiorV67MCQyWLFmSU6XxVGrVqsUXX3xBbGysH2aplFKqLPk7YIgHzvUcEJH2QAR2q6P7nEgRaWuM8bzJd/Y452T6AAmuawAsAo4DfbHJlYhIJDbnYUpJ30Rllp2dzcaNG/OsHKxZs4asrKwiXxsTE8N5553HeeedR9euXbn44ouJjo72w6yVUkqVNX8HDJOB10RkL7k5DGOxgcD3rnO+B3YCX4vIOOAQcDX2UcP97guJyBfAUmANdhWhn+vrAXf+gjEmTUReBJ4UkURyCzc5gAll+k4riL179+YJDv766y9OnDhR5OvCwsLo1KlTTnBw3nnn0bJlS0S83fmqlFKqIvB3wPAGkAEMA+7F1lb4HXjcVR8BY8wJEekBvIDNeYgGtrnO91wV2ATcCTTB1mdYD9xmjPkg3/d8ERsgPI5NpFwG9DLGHCiLN1ienThxguXLl+cEB0uXLs3ZxVCUtm3b5gQGXbt2pWPHjoSGFthkopRSqpISu4tRFaZLly5m2bJlgZ5GiWRlZbFu3bo8qwfr168nO7vozSP169fPExx06dKFmjVr+mHWSpVet27dAJg3b15A56FURSUiy40xXfKPa7fKCiwzM5P4+Hi2bNnC1q1bc762bNnCjh07ipV3EBkZSefOnfMECE2aNNFHC0oppfLQgKGcy8jIIC4uLicQ8AwK4uLicDrz16Q6OYfDwRlnnJEnOOjQoQPBwfrPQCml1KnpnaIcSE9PZ8eOHXmCAffv4+PjvQoKPDVp0iRPUmLnzp2Jiory8eyVUkpVBRow+ElaWhrbt28vNCjYuXNnsXILCtOoUSNatWpFq1ataN26dc7vW7ZsqcGBUkopn9GAoQysWrWKX375JU9gsGvXLkqaYNqkSZOTBgWRkZE+nr1SSilVkAYMZeDXX39l5MiRxT5fRGjSpEmeYMD9+xYtWhAREVGGs1VKKaWKpgFDGWjdunWBMRGhWbNmBVYJWrduTfPmzQkPDw/ATJVSSqni0YChDHTs2JH77rsvT2DQvHlzwsLCAj01pZRSqkQ0YCgDzZo1Y9KkSYGehlJKKeUzjkBPQCmllFLlnwYMSimllCqSBgxKKaWUKpIGDEoppZQqkgYMSimllCqSBgxKKaWUKpIGDEoppZQqkgYMSimllCqSBgxKKaWUKpIGDEoppZQqkgYMSimllCqSBgxKKaWUKpIYYwI9h3JLRA4B8aW4RB0gwUfTqYj0/Vfd91+V3zvo+9f3X7HffzNjTN38gxowlCERWWaM6RLoeQSKvv+q+/6r8nsHff/6/ivn+9dHEkoppZQqkgYMSimllCqSBgxla0qgJxBg+v6rrqr83kHfv77/SkhzGJRSSilVJF1hUEoppVSRNGDwMRHpICJzRSRFRPaKyDgRCQr0vPxBRPqKyLciskdEkkRkuYjcEuh5BYqINHL9PRgRiQr0fPxBRIJF5N8iskVE0kVkt4i8Fuh5+YOI9BeRFa7/5ntE5H0ROS3Q8yoLItJKRN4WkTUi4hSReYWcIyLyhIjsEpFUEVkgIucEYLo+V9T7F5GGIvKKiKx2/XvYJSLvVfR/Dxow+JCI1AJ+AQxwPTAOeAR4JpDz8qOHgSTgIeA64DfgIxEZEdBZBc4r2L+PquR/wAPAf4HLgX8DqYGckD+IyHXAx8Ai7P/7jwGXArNFpDL+nD0DuArYBGw+yTn/Bp4EXgKuxf6/8IuINPDLDMtWUe+/M3AD9t/EtcBIoCuwqCJ/eNAcBh8SkceBUdiiF8ddY6OAp4EG7rHKSkTqGGMS8o19BFxgjGkeoGkFhIhcCnwNPI8NHKobYyp18CAiVwDfAWcbY9YHej7+JCKfAK2NMZ09xq4DvgE6GGM2BGxyZUBEHMaYbNfvZwJ1jDHdPI6HAweA8caYca6xakAc8LYxZozfJ+1DxXj/NYEkY0yWx1gbbIAx2Bjznp+n7BOVMfINpCuBn/IFBp8AEcBlgZmS/+QPFlxWAhV6Gc5brkdQE7ArTBW52pu37gR+rWrBgksIcCzf2FHXr+LnuZQ5983yFC4EooHPPF6TjA0oryzDqflFUe/fGHPUM1hwjW0GUqjAPw81YPCtdsBGzwFjzE7sP5J2AZlR4F3AyZcsK6t7gTBgUqAn4mddgc0iMlFEjrvyeL6s6M9ti2kacImI3CYi0a5Pk89SdQOodoAT2JJvfANV9GehiHQEIqnAPw81YPCtWuR+qvCU6DpWpYhID+CfwPhAz8VfRKQ28B/gYWNMZqDn42cNgMHAOUB/4A7ss9yvRKTSfcr2ZIyZjX3vU7ArDZuAIOCmAE4rkGphl+Sd+cYTgUgRCQ3AnALGlcfyOjaA+jbA0ymx4EBPQFVOInI68BHwjTHmfwGdjH89Byw2xnwf6IkEgLi+rjfGHAYQkX3AfKA7MDeAcytTIhILTMbeFH4A6mNzl74SkZ6F3DhV1fICdrX1sor8QUIDBt9KBGoUMl7LdaxKEJEY7A/NeGBggKfjNyJyBvY5/qWupCewS5AANUTEaYypzDsGEoHt7mDB5XcgA+hAJQ4YsKto3xpjHnMPiMgq7CPK64EvAzWxAEkEokQkKF+wVAtIMcZkBGhefici92F3SdxijFkS6PmUhj6S8K2N5Hs+JyJNsDeNjYW+opIRkUhgFhAKXGOMSQnwlPypNTb57U/sD8xEcvMYdmMTISuzDRSe4CdAUUlyFV07YJXngDFmE3ZLacuAzCiwNmIfybTKN14gz6syE5GbsP/fjzLGfBro+ZSWBgy+9QPQW0Sqe4z1w/7QmB+YKfmPiAQDn2NvnFcYYw4GeEr+9jsQm+/rJdexq7DbKyuzWcBZIlLHY+xSbBC1OjBT8pt44FzPARFpj90hFReICQXYIuA40Nc94PowcS3252SlJyLdgBnABGPMfwM8HZ/QRxK+NRlbtOZLEXkJaIF9jvlqZa/B4PIm9sb4IFDblQDottIYkx6YafmHa1vpPM8xVy4HwMLKXocBm/D3APCdiDwPVMcGTL8YY34P6MzK3mTgNRHZS24Ow1hssFDp8llcN/+rXH9sBESLSB/Xn783xqSIyIvAkyKSiF1VeBj7IbXCr7QV9f6BZtg6LBuBT0XkfI+XHzLGbPPbZH1ICzf5mIh0ACZiE1yOAu8AT1eFpCcRicP+j1KY5saYOP/NpnwQkcHAdKpA4SawJXOBN7B1RzKwhYseMsZU6hwe1y6Qe4Fh2EcQR7ErTo8bY7YHcm5lwRUI7zjJ4ebGmDjX38kT2L+T2sAy4AFjzEq/TLIMFfX+gW7Y/+8L854xZrDPJ+UHGjAopZRSqkiaw6CUUkqpImnAoJRSSqkiacCglFJKqSJpwKCUUkqpImnAoJT6//bu3zWKIA7D+POaFGphEawEQdBOJBCx1MI/QBQFq5BKsLbSImBnYaVgKYqkswgIIiepFMFCEhAFCwWxUURRUJGkuLHYPQjH4Qq5nPfj+cBxu8PcMtO9zOx9R5IaGRgkSVIjA4OkLUtS/uHzvu57t3MtaXRYh0HSlnVVsgNYpioHfXVT23opZS3JQWDPOBTwkSaJpaElbVkp5fnm+yTrwJfu9rrvSJbFlSadWxKSBqp7SyLJgXrL4mKSa0k+JfmRZCnJ7iSHkrSS/EzyNslCj2fOJnmQ5FuS30meJTk+0IlJY87AIGlYXAH2AQtUBzedpzrUaRl4CJwBXgJ3khzu/CjJHNXpiDPABeAs8BVYSXJ0kBOQxplbEpKGxbtSSmf1oFWvEMwD86WUJYAkL4BTwDngdd33OvABOFlK2aj7tYBXwCJwenBTkMaXKwyShsWjrvs39Xer01CfevkZ2A+QZBfVyZj3gXaS6STTQIAV4MR2D1qaFK4wSBoW3Udgb/ylfWd9PQNMUa0kLPZ6aJIdpZR2vwYpTSoDg6RR9h1oA7eAe706GBak/jAwSBpZpZRfSZ4Cs8Cq4UDaPgYGSaPuEvCE6kXJ28BHYC8wB0yVUi7/z8FJ48KXHiWNtFLKKnCM6q+UN4HHwA3gCFWQkNQHloaWJEmNXGGQJEmNDAySJKmRgUGSJDUyMEiSpEYGBkmS1MjAIEmSGhkYJElSIwODJElqZGCQJEmN/gAO4wkLWmR8iAAAAABJRU5ErkJggg==\n", 96 | "text/plain": [ 97 | "
" 98 | ] 99 | }, 100 | "metadata": { 101 | "needs_background": "light" 102 | }, 103 | "output_type": "display_data" 104 | } 105 | ], 106 | "source": [ 107 | "state = 2\n", 108 | "feat = 2\n", 109 | "x = np.arange(0, 21, 1)[:14]\n", 110 | "fig=plt.figure(figsize=(8, 6))\n", 111 | "\n", 112 | "# plt.plot(x[-7:], preds[idx, state,:,0], 'blue', alpha= 0.5, lw=4, linestyle = \"--\", label='Seq2Seq Prediction')\n", 113 | "plt.plot(x[-7:], w1_pred[state,:,feat], 'blue', alpha= 1, lw=3, linestyle = \"--\", label='FC prediction')\n", 114 | "# plt.plot(x[-7:], w1_pred_up[state,:,0], 'blue', alpha= 0.5, lw=1, linestyle = \"--\")#, label='Seq2Seq Prediction'\n", 115 | "# plt.plot(x[-7:], w1_pred_lower[state,:,0], 'blue', alpha= 0.5, lw=1, linestyle = \"--\")#, label='Seq2Seq Prediction'\n", 116 | "plt.fill_between(x[-7:], w1_pred_lower[state,:,feat], w1_pred_up[state,:,feat], color='b', alpha=.2)\n", 117 | "\n", 118 | "plt.plot(x[-7:], ode_preds[state,-7:,feat], 'red', alpha= 1, lw=3, linestyle = \"--\", label='AutoODE prediction')\n", 119 | "plt.fill_between(x[-7:], ode_preds[state,-7:,feat] - ode_std[state,:,feat], \n", 120 | " ode_preds[state,-7:,feat] + ode_std[state,:,feat], color='red', alpha=.3)\n", 121 | "\n", 122 | "\n", 123 | "#plt.plot(x[-7:], w1_pred_lower[state,:,0], 'blue', alpha= 0.5, lw=1, linestyle = \"--\", label='AutoODE')#, label='Seq2Seq Prediction'\n", 124 | "\n", 125 | "#plt.plot(x, trues[idx, state,:,0], 'black', alpha=1, lw=2, label= 'True #Infectives')\n", 126 | "plt.plot(x, w1_true[state,7:,feat], 'black', alpha=1, lw=3, label= 'True #Infectives')\n", 127 | "# plt.plot(x, w1_true[state,:,0], 'black', alpha=1, lw=2, label= 'True #Infectives')\n", 128 | "# plt.plot(x, w1_true[state,:,0], 'black', alpha=1, lw=2, label= 'True #Infectives')\n", 129 | "\n", 130 | "plt.axvline(7, color = \"black\")\n", 131 | "plt.xlabel('Time', size = 16)\n", 132 | "plt.ylabel('Population', size = 16)\n", 133 | "plt.xticks(size = 15)\n", 134 | "plt.yticks(size = 15)\n", 135 | "plt.xlim(-0.5,13.5)\n", 136 | "\n", 137 | "#plt.grid(b=True, c='w', lw=2, ls='-')\n", 138 | "plt.legend(fontsize = 16, loc = 2)\n", 139 | "\n", 140 | "plt.title( \"#Death - \" + us.index[state], fontsize = 20)#+ r\"$N=10000, I_0 = 10, \\beta = 0.8, \\gamma = 0.1$\"+ \"- 07/06~07/12\"\n", 141 | "plt.savefig(us.index[state] + \"_D.png\", dpi = 400 , bbox_inches = \"tight\") #\"SuEIRD_Corr_\" + us.index[idx] + \".png\" \n", 142 | "plt.show()" 143 | ] 144 | } 145 | ], 146 | "metadata": { 147 | "kernelspec": { 148 | "display_name": "MyEnv", 149 | "language": "python", 150 | "name": "myenv" 151 | }, 152 | "language_info": { 153 | "codemirror_mode": { 154 | "name": "ipython", 155 | "version": 3 156 | }, 157 | "file_extension": ".py", 158 | "mimetype": "text/x-python", 159 | "name": "python", 160 | "nbconvert_exporter": "python", 161 | "pygments_lexer": "ipython3", 162 | "version": "3.6.9" 163 | } 164 | }, 165 | "nbformat": 4, 166 | "nbformat_minor": 4 167 | } 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Paper: 2 | Rui Wang, Danielle Maddix, Christos Faloutsos, Yuyang Wang, Rose Yu [Bridging Physics-based and Data-driven modeling for 3 | Learning Dynamical Systems](https://arxiv.org/pdf/2011.10616.pdf), Annual Conference on Learning for Dynamics and Control (L4DC), 2021 4 | 5 | ## Abstract: 6 | How can we learn a dynamical system to make forecasts, when some variables are unobserved? For instance, in COVID-19, we want to forecast the number of infected and death cases but we do not know the count of susceptible and exposed people. While mechanics compartment models are widely-used in epidemic modeling, data-driven models are emerging for disease forecasting. As a case study, we compare these two types of models for COVID-19 forecasting and notice that physics-based models significantly outperform deep learning models. We present a hybrid approach, AutoODE-COVID, which combines a novel compartmental model with automatic differentiation. Our method obtains a 57.4% reduction in mean absolute errors for 7-day ahead COVID-19 forecasting compared with the best deep learning competitor. To understand the inferior performance of deep learning, we investigate the generalization problem in forecasting. Through systematic experiments, we found that deep learning models fail to forecast under shifted distributions either in the data domain or the parameter domain. This calls attention to rethink generalization especially for learning dynamical systems. 7 | 8 | ## Description 9 | 1. ode_nn/: 10 | * DNN.py: Pytorch implementation of Seq2Seq, Auto-FC, Transformer, Neural ODE. 11 | * Graph.py: Pytorch implementation of Graph Attention, Graph Convolution. 12 | * AutoODE.py: Pytorch implementation of AutoODE(-COVID). 13 | * train.py: data loaders, train epoch, validation epoch, test epoch functions. 14 | 15 | 3. Run_DSL.ipynb: train deep sequence models and graph neural nets. 16 | 4. Run_AutoODE.ipynb: train AutoODE-COVID. 17 | 5. Evaluation.ipynb: evaluation functions and prediction visualization 18 | 19 | 20 | ## Requirement 21 | * python 3.6 22 | * pytorch 10.1 23 | * matplotlib 24 | * scipy 25 | * numpy 26 | * pandas 27 | * dgl 28 | 29 | 30 | ## Cite 31 | ``` 32 | @inproceedings{wang2020bridging, 33 | title={Bridging Physics-based and Data-driven modeling for Learning Dynamical Systems}, 34 | author={Rui Wang and Danielle Maddix and Christos Faloutsos and Yuyang Wang and Rose Yu}, 35 | journal={In proceedings of Learning for Dynamics and Control (L4DC)}, 36 | year={2021} 37 | } 38 | ``` 39 | -------------------------------------------------------------------------------- /Run_AutoODE.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using backend: pytorch\n", 13 | "/global/homes/r/rwang2/.conda/envs/myenv/lib/python3.6/site-packages/dgl/base.py:45: DGLWarning: Detected an old version of PyTorch. Suggest using torch>=1.5.0 for the best experience.\n", 14 | " return warnings.warn(message, category=category, stacklevel=1)\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import torch\n", 20 | "import torch.nn as nn\n", 21 | "import os\n", 22 | "from ode_nn import AutoODE_COVID, weight_fun\n", 23 | "from ode_nn import Dataset, train_epoch, eval_epoch, get_lr\n", 24 | "import numpy as np\n", 25 | "import pandas as pd\n", 26 | "import torch.nn.functional as F\n", 27 | "from torch.utils import data\n", 28 | "import matplotlib.pyplot as plt\n", 29 | "import random\n", 30 | "import warnings\n", 31 | "from ode_nn import Dataset_graph, train_epoch_graph, eval_epoch_graph, get_lr\n", 32 | "warnings.filterwarnings(\"ignore\")\n", 33 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Data Preprocessing" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 12, 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "# Read and Preprocess the csv files from John Hopkins Dataset\n", 50 | "# https://github.com/CSSEGISandData/COVID-19/tree/master/csse_covid_19_data/csse_covid_19_daily_reports_us\n", 51 | "direc = \".../ODEs/Data/COVID/\" # Directory that contains daily report csv files.\n", 52 | "list_csv = sorted(os.listdir(direc))\n", 53 | "us = []\n", 54 | "for file in list_csv:\n", 55 | " sample = pd.read_csv(direc + file).set_index(\"Province_State\")[[\"Confirmed\", \"Recovered\", \"Deaths\"]].sort_values(by = \"Confirmed\", ascending = False)\n", 56 | " us.append(sample.drop(['Diamond Princess', 'Grand Princess']))\n", 57 | "us = pd.concat(us, axis=1, join='inner')\n", 58 | "us_data = us.values.reshape(56,-1,3)\n", 59 | "us_data[us_data!=us_data] = 0\n", 60 | "\n", 61 | "#####################################################################################\n", 62 | "# Normalize by total population of each state\n", 63 | "population = pd.read_csv(\".../ode_nn/population_states.csv\", index_col=0)\n", 64 | "scaler = population.loc[us.index].values.reshape(56, 1, 1)*1e6\n", 65 | "us_data = us_data/scaler\n", 66 | "us_data = torch.from_numpy(us_data).float().to(device)\n", 67 | "\n", 68 | "# Mobility Data: beta = 1 - stay_at_home_percentages\n", 69 | "beta = torch.load(\".../ode_nn/mobility/us_beta.pt\").float()\n", 70 | "\n", 71 | "# U.S states 1-0 Adjacency Matrix\n", 72 | "graph = torch.load(\".../ode_nn/mobility/us_graph.pt\").float()" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "## Train AutoODE-COVID" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "##################################################################\n", 89 | "test_idx = 131\n", 90 | "\n", 91 | "# Learning Rate\n", 92 | "lr = 0.01\n", 93 | "\n", 94 | "# number of historic data points for fitting\n", 95 | "input_steps = 10 \n", 96 | "\n", 97 | "# forecasting horizon\n", 98 | "output_steps = 7\n", 99 | "\n", 100 | "# number of epochs for training\n", 101 | "num_epochs = 20000\n", 102 | "\n", 103 | "# select data for training\n", 104 | "data = us_data[:, test_idx-input_steps:test_idx+7]\n", 105 | "y_exact = data[:,:input_steps]\n", 106 | "\n", 107 | "##################################################################\n", 108 | "\n", 109 | "model = AutoODE_COVID(initial_I = data[:,0,0], initial_R = data[:,0,1], initial_D = data[:,0,2],\n", 110 | " num_regions = 56, solver = \"RK4\", n_breaks = 1, graph = graph).to(device)\n", 111 | "\n", 112 | "optimizer = torch.optim.Adam(model.parameters(), lr)\n", 113 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size= 1000, gamma=0.9)\n", 114 | "loss_fun = torch.nn.MSELoss()\n", 115 | "min_loss = 1\n", 116 | "\n", 117 | "##################################################################\n", 118 | "\n", 119 | "for e in range(num_epochs):\n", 120 | " scheduler.step()\n", 121 | " y_approx = model(input_steps)\n", 122 | " loss = loss_fun(y_approx[:,:,-3:], y_exact[:,:input_steps,-3:])\n", 123 | " \n", 124 | "######## Weighted Loss ########\n", 125 | "# loss_weight = weight_fun(input_steps, function = \"sqrt\", feat_weight = True)\n", 126 | "# loss = torch.mean(loss_weight*loss_fun(y_approx[:,:,-3:], y_exact[:,:input_steps,-3:])) \n", 127 | "\n", 128 | "######## A few constraints that can potential improve the model ########\n", 129 | "# positive_constraint = loss_fun(F.relu(-model.beta), torch.tensor(0.0).float().to(device))\n", 130 | "# diagonal_constraint = loss_fun(torch.diagonal(model.A, 0),torch.tensor(1.0).float().to(device))\n", 131 | "# initial_constraint = loss_fun(model.init_S + model.init_E + model.init_I + model.init_R + model.init_U, torch.tensor(1.0).float().to(device))\n", 132 | "# loss += initial_constraint + positive_constraint + diagonal_constraint \n", 133 | " \n", 134 | " if loss.item() < min_loss:\n", 135 | " best_model = model\n", 136 | " min_loss = loss.item()\n", 137 | " optimizer.zero_grad()\n", 138 | " loss.backward(retain_graph=True)\n", 139 | " optimizer.step()\n", 140 | "# if e%1000 == 0:\n", 141 | "# y_approx2 = model(data.shape[1]).data.numpy()\n", 142 | "# y_exact2 = data.data.numpy()\n", 143 | "# print(list_csv[test_idx][:10])\n", 144 | "# #torch.mean(torch.abs(y_approx - y_exact)[:,-7:]).data, torch.mean(torch.abs(y_approx - y_exact)[:,30:]).data\n", 145 | "# for i in range(3):\n", 146 | "# print(np.mean(np.abs(y_approx2*scaler - y_exact2*scaler)[:,-7:, i]))\n", 147 | "\n", 148 | "########################################################################\n", 149 | "name = \"autoode-covid\"\n", 150 | "y_approx = best_model(data.shape[1]).data.numpy()\n", 151 | "y_exact = data.data.numpy()\n", 152 | "print(list_csv[test_idx][:10])\n", 153 | "#torch.mean(torch.abs(y_approx - y_exact)[:,-7:]).data, torch.mean(torch.abs(y_approx - y_exact)[:,30:]).data\n", 154 | "for i in range(3):\n", 155 | " print(np.mean(np.abs(y_approx*scaler - y_exact*scaler)[:,-7:, i]))\n", 156 | "\n", 157 | "torch.save({\"model\": best_model,\n", 158 | " \"preds\": y_approx*scaler,\n", 159 | " \"trues\": y_exact*scaler},\n", 160 | " \".pt\")" 161 | ] 162 | } 163 | ], 164 | "metadata": { 165 | "kernelspec": { 166 | "display_name": "MyEnv", 167 | "language": "python", 168 | "name": "myenv" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.6.9" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 4 185 | } 186 | -------------------------------------------------------------------------------- /Run_DSL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 6, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "from ode_nn import Seq2Seq, Auto_FC, Transformer, Latent_ODE, Transformer_EncoderOnly, GAT, GCN\n", 12 | "from ode_nn import Dataset, train_epoch, eval_epoch, get_lr, Dataset_graph, train_epoch_graph, eval_epoch_graph\n", 13 | "import dgl\n", 14 | "import numpy as np\n", 15 | "import time\n", 16 | "from torch.utils import data\n", 17 | "import matplotlib.pyplot as plt\n", 18 | "import os\n", 19 | "import pandas as pd\n", 20 | "import warnings\n", 21 | "warnings.filterwarnings(\"ignore\")\n", 22 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "## Data Loader" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "input_length = 14\n", 39 | "mid = 14\n", 40 | "output_length = 7\n", 41 | "batch_size = 128\n", 42 | "\n", 43 | "# Directories to the samples of subsequences\n", 44 | "train_direc = '/global/cscratch1/sd/rwang2/ODEs/Data/covid_seq/train/sample_'\n", 45 | "test_direc = '/global/cscratch1/sd/rwang2/ODEs/Data/covid_seq/test/sample_'\n", 46 | "\n", 47 | "train_indices = list(range(4000))\n", 48 | "valid_indices = list(range(4000, 5250))\n", 49 | "test_indices = list(range(150))\n", 50 | "\n", 51 | "train_set = Dataset(train_indices, input_length, mid, output_length, train_direc, entire_target = True)\n", 52 | "valid_set = Dataset(valid_indices, input_length, mid, output_length, train_direc, entire_target = True)\n", 53 | "test_set = Dataset(test_indices, input_length, mid, output_length, test_direc, entire_target = True)\n", 54 | "\n", 55 | "train_loader = data.DataLoader(train_set, batch_size = batch_size, shuffle = True)\n", 56 | "valid_loader = data.DataLoader(valid_set, batch_size = batch_size, shuffle = False)\n", 57 | "test_loader = data.DataLoader(test_set, batch_size = batch_size, shuffle = False)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "## Train Deep Sequence Models" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "name = \"...\"\n", 74 | "model = Auto_FC(input_length = input_length, input_dim = 3, output_dim = 3, hidden_dim = 16, quantile = True).to(device)\n", 75 | "# model = Seq2Seq(input_dim = 3, output_dim = 3, hidden_dim = 128, num_layers = 1, quantile = True).to(device)\n", 76 | "# model = Transformer(input_dim = 3, output_dim = 3, nhead = 4, d_model = 32, num_layers = 3, dim_feedforward = 64, quantile = True).to(device)\n", 77 | "# model = Latent_ODE(latent_dim = 64, obs_dim = 3, nhidden = 128, rhidden = 128, quantile = True, aug = False).to(device).to(device)\n", 78 | "\n", 79 | "####################################\n", 80 | "learning_rate = 0.01\n", 81 | "optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)\n", 82 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size= 1, gamma=0.95)\n", 83 | "loss_fun = nn.MSELoss()\n", 84 | "print(sum(p.numel() for p in model.parameters() if p.requires_grad))\n", 85 | "train_rmse = []\n", 86 | "valid_rmse = []\n", 87 | "test_rmse = []\n", 88 | "min_rmse = 1\n", 89 | "\n", 90 | "for i in range(1, 200):\n", 91 | " start = time.time()\n", 92 | " scheduler.step()\n", 93 | " model.train()\n", 94 | " train_rmse.append(train_epoch(model, train_loader, optimizer, loss_fun)[-1])#, feed_tgt = True\n", 95 | " model.eval()\n", 96 | " preds, trues, rmse = eval_epoch(model, valid_loader, loss_fun, concat_input = True)\n", 97 | " valid_rmse.append(rmse)\n", 98 | " if valid_rmse[-1] < min_rmse:\n", 99 | " min_rmse = valid_rmse[-1] \n", 100 | " best_model = model \n", 101 | " torch.save(best_model, name + \".pth\")\n", 102 | " end = time.time()\n", 103 | " if (len(train_rmse) > 30 and np.mean(valid_rmse[-5:]) >= np.mean(valid_rmse[-10:-5])):\n", 104 | " break\n", 105 | " print(\"Epoch \" + str(i) + \": \", \"train rmse:\", train_rmse[-1], \"valid rmse:\",valid_rmse[-1], \n", 106 | " \"time:\",round((end-start)/60,3), \"Learning rate:\", format(get_lr(optimizer), \"5.2e\"))\n", 107 | "\n", 108 | "preds, trues, rmses = eval_epoch(best_model, test_loader, loss_fun, concat_input = False)\n", 109 | "\n", 110 | "torch.save({\"preds\": preds[:,-7:],\n", 111 | " \"trues\": trues,\n", 112 | " \"model\": best_model},\n", 113 | " name + \".pt\")" 114 | ] 115 | }, 116 | { 117 | "cell_type": "markdown", 118 | "metadata": {}, 119 | "source": [ 120 | "## Train Graphic Models" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": 4, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "input_length = 14\n", 130 | "mid = 14\n", 131 | "output_length = 7\n", 132 | "batch_size = 4\n", 133 | "####################################\n", 134 | "train_direc = '.../graph_train/sample_'\n", 135 | "test_direc = '.../graph_test/sample_'\n", 136 | "\n", 137 | "train_indices = list(range(80))\n", 138 | "valid_indices = list(range(80, 100))\n", 139 | "test_indices = list(range(3))\n", 140 | "####################################\n", 141 | "train_set = Dataset_graph(train_indices, input_length, mid, output_length, train_direc)\n", 142 | "valid_set = Dataset_graph(valid_indices, input_length, mid, output_length, train_direc)\n", 143 | "test_set = Dataset_graph(test_indices, input_length, mid, output_length, test_direc)\n", 144 | "\n", 145 | "train_loader = data.DataLoader(train_set, batch_size = batch_size, shuffle = True)\n", 146 | "valid_loader = data.DataLoader(valid_set, batch_size = batch_size, shuffle = False)\n", 147 | "test_loader = data.DataLoader(test_set, batch_size = batch_size, shuffle = False)\n", 148 | "\n", 149 | "################################\n", 150 | "# U.S. states 1-0 adjacency matrix\n", 151 | "graph = torch.load(\"/global/cscratch1/sd/rwang2/ODEs/Main/ode_nn/mobility/us_graph.pt\")[:50,:50]\n", 152 | "G = dgl.DGLGraph().to(device)\n", 153 | "G.add_nodes(50)\n", 154 | "for i in range(50):\n", 155 | " for j in range(50):\n", 156 | " if graph[i,j] == 1:\n", 157 | " G.add_edge(i,j)" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": null, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "name = \"covid_gcn\"\n", 167 | "model = GCN(in_dim = 42, out_dim = 3, hidden_dim = 16, num_layer = 3).to(device)\n", 168 | "#model = GAT(in_dim = 42, out_dim = 3, hidden_dim = 32, num_heads = 4, num_layer = 6).to(device)\n", 169 | "print(sum(p.numel() for p in model.parameters() if p.requires_grad))\n", 170 | "learning_rate = 0.01\n", 171 | "optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)\n", 172 | "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size= 1, gamma=0.9)\n", 173 | "loss_fun = nn.MSELoss()\n", 174 | "sum(p.numel() for p in model.parameters() if p.requires_grad) \n", 175 | "train_rmse = []\n", 176 | "valid_rmse = []\n", 177 | "test_rmse = []\n", 178 | "min_rmse = 1\n", 179 | "for i in range(1, 200):\n", 180 | " start = time.time()\n", 181 | " scheduler.step()\n", 182 | " model.train()\n", 183 | " train_rmse.append(train_epoch_graph(model, train_loader, optimizer, loss_fun, G)[-1])\n", 184 | " model.eval()\n", 185 | " preds, trues, rmse = eval_epoch_graph(model, valid_loader, loss_fun, G)\n", 186 | " valid_rmse.append(rmse)\n", 187 | " if valid_rmse[-1] < min_rmse:\n", 188 | " min_rmse = valid_rmse[-1] \n", 189 | " best_model = model \n", 190 | " torch.save(best_model, name + \".pth\")\n", 191 | " end = time.time()\n", 192 | " if (len(train_rmse) > 30 and np.mean(valid_rmse[-5:]) >= np.mean(valid_rmse[-10:-5])):\n", 193 | " break\n", 194 | " print(\"Epoch \" + str(i) + \": \", \"train rmse:\", train_rmse[-1], \"valid rmse:\",valid_rmse[-1], \n", 195 | " \"time:\",round((end-start)/60,3), \"Learning rate:\", format(get_lr(optimizer), \"5.2e\"))\n", 196 | " \n", 197 | "preds, trues, rmse = eval_epoch_graph(best_model, test_loader, loss_fun, G)\n", 198 | "torch.save({\"preds\": preds,\n", 199 | " \"trues\": trues,\n", 200 | " \"rmse\": np.sqrt(np.mean((preds - trues[:,:,-7:])**2))}, \n", 201 | " name + \".pt\")" 202 | ] 203 | } 204 | ], 205 | "metadata": { 206 | "kernelspec": { 207 | "display_name": "MyEnv", 208 | "language": "python", 209 | "name": "myenv" 210 | }, 211 | "language_info": { 212 | "codemirror_mode": { 213 | "name": "ipython", 214 | "version": 3 215 | }, 216 | "file_extension": ".py", 217 | "mimetype": "text/x-python", 218 | "name": "python", 219 | "nbconvert_exporter": "python", 220 | "pygments_lexer": "ipython3", 221 | "version": "3.6.9" 222 | } 223 | }, 224 | "nbformat": 4, 225 | "nbformat_minor": 4 226 | } 227 | -------------------------------------------------------------------------------- /ode_nn/AutoODE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.utils import data 5 | import matplotlib.pyplot as plt 6 | import torch.nn.functional as F 7 | import random 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | def weight_fun(num_steps, function = "linear", feat_weight = False): 13 | if function == "linear": 14 | weight = torch.linspace(0, 1, num_steps).reshape(1,-1,1)*2 / num_steps 15 | if function == "sqrt": 16 | sqrt_func = lambda x: torch.sqrt(x) 17 | weight = sqrt_func(torch.linspace(0, 1, num_steps).reshape(1,-1,1)) / torch.sum(sqrt_func(torch.linspace(0, 1, num_steps))) 18 | if feat_weight: 19 | weight = weight.repeat(1,1,3) 20 | weight[:,:,2] = weight[:,:,2] #* 12 21 | weight[:,:,1] = weight[:,:,1] #* 4 22 | return weight 23 | 24 | class PiecewiseLinearModel(nn.Module): 25 | def __init__(self, n_breaks, num_regions): 26 | super(PiecewiseLinearModel, self).__init__() 27 | self.breaks = nn.Parameter(torch.rand((num_regions, 1, n_breaks))) 28 | self.linear = nn.Linear(n_breaks + 1, 1) 29 | def forward(self, xx): 30 | if len(xx.shape) < 3: 31 | xx = xx.unsqueeze(-1) 32 | out = torch.cat([xx, F.relu(xx - self.breaks)],2) 33 | return self.linear(out).squeeze(-1) 34 | 35 | class AutoODE_COVID(nn.Module): 36 | def __init__(self, initial_I, initial_R, initial_D, num_regions, 37 | solver = "Euler", Corr = False, low_rank = None, 38 | n_breaks = 0, graph = None, beta = None): 39 | super(AutoODE_COVID, self).__init__() 40 | self.num_regions = num_regions 41 | self.init_I = initial_I 42 | self.init_R = initial_R 43 | self.init_D = initial_D 44 | self.init_E = nn.Parameter(torch.tensor([0.5] * num_regions).float().to(device)) 45 | self.init_S = nn.Parameter(torch.tensor([0.5] * num_regions).float().to(device)) 46 | 47 | if Corr: 48 | if num_regions == 1: 49 | self.A = torch.ones(1, 1).to(device) 50 | else: 51 | if low_rank: 52 | if symmetric: 53 | self.B = nn.Parameter(torch.rand(num_regions, low_rank).to(device)) 54 | self.A = torch.mm(self.B, self.B.T) 55 | else: 56 | self.B = nn.Parameter(torch.rand(num_regions, low_rank).to(device)) 57 | self.C = nn.Parameter(torch.rand(low_rank, num_regions).to(device)) 58 | self.A = torch.mm(self.B, self.C) 59 | else: 60 | self.A = nn.Parameter(torch.rand(num_regions, num_regions).to(device)) 61 | else: 62 | self.A = np.zeros((num_regions, num_regions)) 63 | np.fill_diagonal(self.A, 1.0) 64 | self.A = torch.from_numpy(self.A).float().to(device) 65 | 66 | self.graph = graph 67 | if beta is None: 68 | if n_breaks > 0: 69 | self.plm = PiecewiseLinearModel(n_breaks = n_breaks, num_regions = num_regions) 70 | else: 71 | self.beta = nn.Parameter(torch.rand(num_regions).to(device)/10) 72 | else: 73 | self.beta = beta 74 | self.n_breaks = n_breaks 75 | self.gamma = nn.Parameter(torch.rand(num_regions).to(device)/10) 76 | self.sigma = nn.Parameter(torch.rand(num_regions).to(device)/10) 77 | self.mu = nn.Parameter(torch.rand(num_regions).to(device)/10) 78 | self.step = torch.tensor(0.01).float().to(device) 79 | self.a = nn.Parameter(torch.rand(num_regions).to(device)/10) 80 | self.b = nn.Parameter(torch.rand(num_regions).to(device)/10) 81 | self.solver = solver 82 | self.init_U = (1-self.mu)*self.sigma*self.init_E 83 | 84 | def Euler(self, num_steps): 85 | t = torch.linspace(1, num_steps, num_steps).repeat(self.num_regions, 1) 86 | if self.n_breaks > 0: 87 | beta = self.plm(t) 88 | else: 89 | beta = self.beta.repeat(1, num_steps) 90 | S_pred = [self.init_S] 91 | E_pred = [self.init_E] 92 | I_pred = [self.init_I] 93 | R_pred = [self.init_R] 94 | D_pred = [self.init_D] 95 | for n in range(num_steps - 1): 96 | if self.graph is None: 97 | S_pred.append(S_pred[n] - beta[:, n+1] * (torch.mm(self.A, ((I_pred[n] + E_pred[n]) * S_pred[n]).reshape(-1,1)).squeeze(1)) * self.step) 98 | E_pred.append(E_pred[n] + (beta[:, n+1] * S_pred[n] * (I_pred[n]+ E_pred[n]) - self.sigma * E_pred[n]) * self.step) 99 | else: 100 | S_pred.append(S_pred[n] - beta[:, n+1] * (torch.mm(self.graph*self.A, ((I_pred[n] + E_pred[n]) * S_pred[n]).reshape(-1,1)).squeeze(1)) * self.step) 101 | E_pred.append(E_pred[n] + (beta[:, n+1] * (torch.mm(self.graph*self.A, ((I_pred[n] + E_pred[n]) * S_pred[n]).reshape(-1,1)).squeeze(1)) - self.sigma * E_pred[n]) * self.step) 102 | 103 | I_pred.append(I_pred[n] + (self.mu * self.sigma * E_pred[n] - self.gamma*I_pred[n]) * self.step) 104 | R_pred.append(R_pred[n] + self.gamma * I_pred[n] * self.step) 105 | D_pred.append(D_pred[n] + self.a * torch.exp(- self.b * (n + 1) * self.step) * (R_pred[n+1] - R_pred[n])) 106 | y_pred = torch.cat([torch.stack(S_pred).transpose(0,1).unsqueeze(-1), 107 | (torch.stack(E_pred)*(1-self.mu.unsqueeze(0))*self.sigma.unsqueeze(0)).transpose(0,1).unsqueeze(-1), 108 | torch.stack(E_pred).transpose(0,1).unsqueeze(-1), 109 | torch.stack(I_pred).transpose(0,1).unsqueeze(-1), 110 | torch.stack(R_pred).transpose(0,1).unsqueeze(-1), 111 | torch.stack(D_pred).transpose(0,1).unsqueeze(-1)], dim = -1) 112 | return y_pred 113 | 114 | def f_S(self, S_n, I_n, E_n, beta, n): 115 | return -beta[:, n+1] * (torch.mm(self.A, ((I_n + E_n) * S_n).reshape(-1,1)).squeeze(1)) 116 | 117 | def f_E(self, S_n, I_n, E_n, beta, n): 118 | return beta[:, n+1] * S_n * (I_n + E_n) - self.sigma * E_n 119 | 120 | def f_I(self, I_n, E_n): 121 | return self.mu * self.sigma * E_n - self.gamma*I_n 122 | 123 | def f_R(self, I_n): 124 | return self.gamma*I_n 125 | 126 | def RK4_update(self, f_n, k1, k2, k3, k4): 127 | return f_n + 1/6 * (k1 + 2 * k2 + 2 * k3 + k4) * self.step 128 | 129 | def RK4(self, num_steps): 130 | 131 | t = torch.linspace(1, num_steps, num_steps).repeat(self.num_regions, 1) 132 | if self.n_breaks > 0: 133 | beta = self.plm(t) 134 | else: 135 | beta = self.beta.repeat(1, num_steps) 136 | 137 | S_pred = [self.init_S] 138 | E_pred = [self.init_E] 139 | I_pred = [self.init_I] 140 | R_pred = [self.init_R] 141 | D_pred = [self.init_D] 142 | for n in range(num_steps-1): 143 | # dt * f(t[n], y[n]) 144 | k1_S = self.f_S(S_pred[n], I_pred[n], E_pred[n], beta, n) 145 | k1_E = self.f_E(S_pred[n], I_pred[n], E_pred[n], beta, n) 146 | k1_I = self.f_I(I_pred[n], E_pred[n]) 147 | k1_R = self.f_R(I_pred[n]) 148 | 149 | # dt * f(t[n] + dt/2, y[n] + k1/2) 150 | S_plus_k1_half = S_pred[n] + k1_S / 2 * self.step 151 | I_plus_k1_half = I_pred[n] + k1_I / 2 * self.step 152 | E_plus_k1_half = E_pred[n] + k1_E / 2 * self.step 153 | 154 | k2_S = self.f_S(S_plus_k1_half, I_plus_k1_half, E_plus_k1_half, beta, n) 155 | k2_E = self.f_E(S_plus_k1_half, I_plus_k1_half, E_plus_k1_half, beta, n) 156 | k2_I = self.f_I(I_plus_k1_half, E_plus_k1_half) 157 | k2_R = self.f_R(I_plus_k1_half) 158 | 159 | # dt * f(t[n] + dt/2, y[n] + k2/2) 160 | S_plus_k2_half = S_pred[n] + k2_S / 2 * self.step 161 | I_plus_k2_half = I_pred[n] + k2_I / 2 * self.step 162 | E_plus_k2_half = E_pred[n] + k2_E / 2 * self.step 163 | 164 | k3_S = self.f_S(S_plus_k2_half, I_plus_k2_half, E_plus_k2_half, beta, n) 165 | k3_E = self.f_E(S_plus_k2_half, I_plus_k2_half, E_plus_k2_half, beta, n) 166 | k3_I = self.f_I(I_plus_k2_half, E_plus_k2_half) 167 | k3_R = self.f_R(I_plus_k2_half) 168 | 169 | # dt * f(t[n] + dt, y[n] + k3) 170 | S_plus_k3 = S_pred[n] + k3_S * self.step 171 | I_plus_k3 = I_pred[n] + k3_I * self.step 172 | E_plus_k3 = E_pred[n] + k3_E * self.step 173 | 174 | k4_S = self.f_S(S_plus_k3, I_plus_k3, E_plus_k3, beta, n) 175 | k4_E = self.f_E(S_plus_k3, I_plus_k3, E_plus_k3, beta, n) 176 | k4_I = self.f_I(I_plus_k3, E_plus_k3) 177 | k4_R = self.f_R(I_plus_k3) 178 | 179 | S_pred.append(self.RK4_update(S_pred[n], k1_S, k2_S, k3_S, k4_S)) 180 | E_pred.append(self.RK4_update(E_pred[n], k1_E, k2_E, k3_E, k4_E)) 181 | I_pred.append(self.RK4_update(I_pred[n], k1_I, k2_I, k3_I, k4_I)) 182 | R_pred.append(self.RK4_update(R_pred[n], k1_R, k2_R, k3_R, k4_R)) 183 | 184 | for n in range(num_steps - 1): 185 | D_pred.append(D_pred[n] + (self.a * (n * self.step) + self.b) * (R_pred[n+1] - R_pred[n])) 186 | 187 | y_pred = torch.cat([torch.stack(S_pred).transpose(0,1).unsqueeze(-1), 188 | (torch.stack(E_pred)*(1-self.mu.unsqueeze(0))*self.sigma.unsqueeze(0)).transpose(0,1).unsqueeze(-1), 189 | torch.stack(E_pred).transpose(0,1).unsqueeze(-1), 190 | torch.stack(I_pred).transpose(0,1).unsqueeze(-1), 191 | torch.stack(R_pred).transpose(0,1).unsqueeze(-1), 192 | torch.stack(D_pred).transpose(0,1).unsqueeze(-1)], dim = -1) 193 | return y_pred 194 | 195 | 196 | def forward(self, num_steps): 197 | if self.solver == "Euler": 198 | return self.Euler(num_steps)[:,:,-3:] 199 | elif self.solver == "RK4": 200 | return self.RK4(num_steps)[:,:,-3:] 201 | else: 202 | print("Error") 203 | 204 | 205 | ##############################################################################################################################3 206 | class Auto_ODE_SEIR(nn.Module): 207 | def __init__(self, initial, solver = "Euler"): 208 | super(Auto_ODE_SEIR, self).__init__() 209 | self.initial = torch.nn.Parameter(initial) #10000.0)#torch.tensor([0.99, 0., 0.01, 0.]).cuda()#torch.rand(4)/4 210 | self.beta = torch.nn.Parameter(torch.rand(1)) #torch.tensor(0.9).cuda()# 211 | self.gamma = torch.nn.Parameter(torch.rand(1)) #torch.tensor(0.1).cuda()# 212 | self.sigma = torch.nn.Parameter(torch.rand(1)) #torch.tensor(0.5).cuda()# 213 | self.step = torch.tensor(0.5) 214 | self.solver = solver 215 | 216 | def Euler(self, t): 217 | S_pred = [self.initial[0].reshape(-1,1)] 218 | E_pred = [self.initial[1].reshape(-1,1)] 219 | I_pred = [self.initial[2].reshape(-1,1)] 220 | R_pred = [self.initial[3].reshape(-1,1)] 221 | for n in range(t-1): 222 | S_pred.append((S_pred[n] - self.beta*S_pred[n]*I_pred[n]*self.step).reshape(-1,1)) 223 | E_pred.append((E_pred[n] + (self.beta*S_pred[n]*I_pred[n] - self.sigma*E_pred[n])*self.step).reshape(-1,1)) 224 | I_pred.append((I_pred[n] + (self.sigma*E_pred[n] - self.gamma*I_pred[n])*self.step).reshape(-1,1)) 225 | R_pred.append((R_pred[n] + self.gamma*I_pred[n]*self.step).reshape(-1,1)) 226 | y_pred = torch.cat([torch.cat(S_pred, dim = 0), 227 | torch.cat(E_pred, dim = 0), 228 | torch.cat(I_pred, dim = 0), 229 | torch.cat(R_pred, dim = 0)], dim = 1) 230 | return y_pred 231 | 232 | def RK4(self, t): 233 | S_pred = [self.initial[0].reshape(-1,1)] 234 | E_pred = [self.initial[1].reshape(-1,1)] 235 | I_pred = [self.initial[2].reshape(-1,1)] 236 | R_pred = [self.initial[3].reshape(-1,1)] 237 | 238 | for n in range(t-1): 239 | k1 = self.beta*S_pred[n]*I_pred[n]*self.step #dt * f(t[n], y[n]) 240 | k2 = (self.beta*(S_pred[n]+ k1/2)*I_pred[n])*self.step #dt * f(t[n] + dt/2, y[n] + k1/2) 241 | k3 = (self.beta*(S_pred[n]+ k2/2)*I_pred[n])*self.step #dt * f(t[n] + dt/2, y[n] + k2/2) 242 | k4 = (self.beta*(S_pred[n]+ k3)*I_pred[n])*self.step #dt * f(t[n] + dt, y[n] + k3) 243 | S_pred.append((S_pred[n] - 1/6 * (k1 + 2*k2 + 2*k3 + k4)).reshape(-1,1)) 244 | 245 | k1 = (self.beta*S_pred[n]*I_pred[n] - self.sigma*E_pred[n])*self.step #dt * f(t[n], y[n]) 246 | k2 = (self.beta*S_pred[n]*I_pred[n] - self.sigma*(E_pred[n] + k1/2))*self.step #dt * f(t[n] + dt/2, y[n] + k1/2) 247 | k3 = (self.beta*S_pred[n]*I_pred[n] - self.sigma*(E_pred[n] + k2/2))*self.step #dt * f(t[n] + dt/2, y[n] + k2/2) 248 | k4 = (self.beta*S_pred[n]*I_pred[n] - self.sigma*(E_pred[n] + k3))*self.step #dt * f(t[n] + dt, y[n] + k3) 249 | E_pred.append((E_pred[n] + 1/6 * (k1 + 2*k2 + 2*k3 + k4)).reshape(-1,1)) 250 | 251 | k1 = (self.sigma*E_pred[n] - self.gamma*I_pred[n])*self.step #dt * f(t[n], y[n]) 252 | k2 = (self.sigma*E_pred[n] - self.gamma*(I_pred[n] + k1/2))*self.step #dt * f(t[n] + dt/2, y[n] + k1/2) 253 | k3 = (self.sigma*E_pred[n] - self.gamma*(I_pred[n] + k2/2))*self.step #dt * f(t[n] + dt/2, y[n] + k2/2) 254 | k4 = (self.sigma*E_pred[n] - self.gamma*(I_pred[n] + k3))*self.step #dt * f(t[n] + dt, y[n] + k3) 255 | I_pred.append((I_pred[n] + 1/6 * (k1 + 2*k2 + 2*k3 + k4)).reshape(-1,1)) 256 | 257 | R_pred.append((R_pred[n] + self.gamma*I_pred[n]*self.step).reshape(-1,1)) 258 | 259 | y_pred = torch.cat([torch.cat(S_pred, dim = 0), 260 | torch.cat(E_pred, dim = 0), 261 | torch.cat(I_pred, dim = 0), 262 | torch.cat(R_pred, dim = 0)], dim = 1) 263 | return y_pred 264 | 265 | 266 | def forward(self, t): 267 | if self.solver == "Euler": 268 | return self.Euler(t) 269 | elif self.solver == "RK4": 270 | return self.RK4(t) 271 | else: 272 | print("Error") 273 | 274 | 275 | 276 | # def weight_fun(num_steps, function = "linear", feat_weight = False): 277 | # if function == "linear": 278 | # weight = torch.linspace(0, 1, num_steps).reshape(1,-1,1)*2/num_steps 279 | # if function == "sqrt": 280 | # sqrt_func = lambda x: torch.sqrt(x) 281 | # weight = sqrt_func(torch.linspace(0, 1, num_steps).reshape(1,-1,1))/torch.sum(sqrt_func(torch.linspace(0, 1, num_steps))) 282 | # return weight 283 | 284 | class Auto_ODE_LV(nn.Module): 285 | def __init__(self, num_time_series, p0): 286 | super(Auto_ODE_LV, self).__init__() 287 | self.num_time_series = num_time_series 288 | self.p0 = p0#nn.Parameter(torch.rand(num_time_series).float()) 289 | self.r = nn.Parameter(torch.rand(num_time_series).float()/10) 290 | self.k = nn.Parameter(torch.rand(num_time_series).float()*100) 291 | self.A = nn.Parameter(torch.rand(num_time_series, num_time_series).float()/10) 292 | 293 | def solve(self, num_steps): 294 | p = [] 295 | p.append(self.p0) 296 | for n in range(num_steps-1): # element-wise vector division and multiplication 297 | mat_vec_prod = torch.mm(self.A, p[n].reshape(-1, 1)).squeeze(-1) 298 | p.append((1 + self.r * (1 - mat_vec_prod )) * p[n])#/ self.k 299 | 300 | return torch.cat(p, dim=0).reshape(num_steps, self.num_time_series)#.T 301 | 302 | def forward(self, num_steps): 303 | return self.solve(num_steps) 304 | 305 | 306 | 307 | class Auto_ODE_FHN(nn.Module): 308 | def __init__(self, initials, solver = "Euler"): 309 | super(Auto_ODE_FHN, self).__init__() 310 | self.x0 = initials[0] 311 | self.y0 = initials[1] 312 | self.solver = solver 313 | #print(self.linear(torch.rand(1)).shape) 314 | 315 | self.a = nn.Parameter(torch.tensor(0.3).float())#self.embed_a(torch.rand(1).float())# 316 | self.b = nn.Parameter(torch.tensor(0.3).float())#self.embed_b(torch.rand(1).float())# 317 | self.c = nn.Parameter(torch.tensor(2.0).float())#self.embed_c(torch.rand(1).float())#n 318 | self.step = torch.tensor(0.1)#)#torch.tensor(1.0)#nn.Parameter( 319 | 320 | def Euler(self, num_steps): 321 | #self.a, self.b, self.c, self.d = self.embed(torch.rand(1).reshape(1,1))[0] 322 | x = [self.x0.reshape(-1,1)] 323 | y = [self.y0.reshape(-1,1)] 324 | for n in range(num_steps-1): 325 | x.append(x[n] + (self.c*(x[n] + y[n] - x[n]*x[n]*x[n]/3)).reshape(-1,1)*self.step) 326 | y.append(y[n] - ((1/self.c)*(x[n] + self.b*y[n] - self.a)).reshape(-1,1)*self.step) 327 | #print((self.c*(x[n] + y[n] - x[n]**3/3)).reshape(-1,1).shape) 328 | return torch.cat([torch.cat(x, dim = 0), torch.cat(y, dim = 0)], dim = 1) 329 | 330 | def f_x(self, x_n, y_n, c): 331 | return (c*(x_n + y_n - x_n**3/3)).reshape(-1,1) 332 | 333 | def f_y(self, x_n, y_n, a, b, c): 334 | return - ((1/c)*(x_n + b*y_n - a)).reshape(-1,1) 335 | 336 | def RK4_update(self, f_n, k1, k2, k3, k4): 337 | return f_n + 1/6 * (k1 + 2 * k2 + 2 * k3 + k4) * self.step 338 | 339 | def RK4(self, num_steps): 340 | #self.a, self.b, self.c, self.d = self.embed(torch.rand(1).reshape(1,1))[0] 341 | x = [self.x0.reshape(-1,1)] 342 | y = [self.y0.reshape(-1,1)] 343 | 344 | for n in range(num_steps-1): 345 | # dt * f(t[n], y[n]) 346 | k1_x = self.f_x(x[n], y[n], self.c) 347 | k1_y = self.f_y(x[n], y[n], self.a, self.b, self.c) 348 | 349 | # dt * f(t[n] + dt/2, y[n] + k1/2) 350 | x_plus_k1_half = x[n] + k1_x / 2 * self.step 351 | y_plus_k1_half = y[n] + k1_y / 2 * self.step 352 | 353 | k2_x = self.f_x(x_plus_k1_half, y_plus_k1_half, self.c) 354 | k2_y = self.f_y(x_plus_k1_half, y_plus_k1_half, self.a, self.b, self.c) 355 | 356 | # dt * f(t[n] + dt/2, y[n] + k2/2) 357 | x_plus_k2_half = x[n] + k2_x / 2 * self.step 358 | y_plus_k2_half = y[n] + k2_y / 2 * self.step 359 | 360 | k3_x = self.f_x(x_plus_k2_half, y_plus_k2_half, self.c) 361 | k3_y = self.f_y(x_plus_k2_half, y_plus_k2_half, self.a, self.b, self.c) 362 | 363 | # dt * f(t[n] + dt, y[n] + k3) 364 | x_plus_k3 = x[n] + k3_x * self.step 365 | y_plus_k3 = y[n] + k3_y * self.step 366 | 367 | k4_x = self.f_x(x_plus_k3, y_plus_k3, self.c) 368 | k4_y = self.f_y(x_plus_k3, y_plus_k3, self.a, self.b, self.c) 369 | 370 | 371 | x.append(self.RK4_update(x[n], k1_x, k2_x, k3_x, k4_x)) 372 | y.append(self.RK4_update(y[n], k1_y, k2_y, k3_y, k4_y)) 373 | #print(x[-1], y[-1]) 374 | return torch.cat([torch.cat(x, dim = 0), torch.cat(y, dim = 0)], dim = 1) 375 | def forward(self, num_steps): 376 | if self.solver == "Euler": 377 | return self.Euler(num_steps) 378 | elif self.solver == "RK4": 379 | return self.RK4(num_steps) 380 | else: 381 | print("error") 382 | -------------------------------------------------------------------------------- /ode_nn/DNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | #from ode_nn.torchdiffeq import odeint_adjoint as odeint 5 | from ode_nn.torchdiffeq import odeint 6 | 7 | # from torch.nn import TransformerEncoder, TransformerEncoderLayer 8 | 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | ######## Auto-FC ######## 12 | class Auto_FC(nn.Module): 13 | def __init__(self, input_length, input_dim, output_dim, hidden_dim, quantile = False): 14 | super(Auto_FC, self).__init__() 15 | self.quantile = quantile 16 | self.input_dim = input_dim 17 | self.output_dim = output_dim 18 | self.input_length = input_length 19 | self.model = nn.Sequential( 20 | nn.Linear(input_length*input_dim, hidden_dim), 21 | nn.LeakyReLU(), 22 | nn.Linear(hidden_dim, hidden_dim), 23 | nn.LeakyReLU(), 24 | nn.Linear(hidden_dim, hidden_dim), 25 | nn.LeakyReLU(), 26 | nn.Linear(hidden_dim, hidden_dim), 27 | nn.LeakyReLU(), 28 | nn.Linear(hidden_dim, hidden_dim), 29 | nn.LeakyReLU(), 30 | nn.Linear(hidden_dim, output_dim) 31 | ) 32 | if self.quantile: 33 | self.quantile = nn.Linear(1, 3) 34 | 35 | def forward(self, xx, output_length): 36 | xx = xx.reshape(xx.shape[0], -1) 37 | outputs = [] 38 | for i in range(output_length): 39 | out = self.model(xx) 40 | xx = torch.cat([xx[:, self.input_dim:], out], dim = 1) 41 | outputs.append(out.unsqueeze(1)) 42 | out = torch.cat(outputs, dim = 1) 43 | if self.quantile: 44 | out = self.quantile(out.unsqueeze(-1)) 45 | return out 46 | 47 | ######## Seq2Seq ######## 48 | class Encoder(nn.Module): 49 | def __init__(self, input_dim, hidden_dim, num_layers, dropout_rate = 0): 50 | super(Encoder, self).__init__() 51 | self.num_layers = num_layers 52 | self.hidden_dim = hidden_dim 53 | self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers, 54 | dropout = dropout_rate, batch_first = True) 55 | 56 | def forward(self, source): 57 | outputs, hidden = self.lstm(source) 58 | return outputs, hidden 59 | 60 | 61 | class Decoder(nn.Module): 62 | def __init__(self, output_dim, hidden_dim, num_layers, dropout_rate = 0): 63 | super(Decoder, self).__init__() 64 | self.output_dim = output_dim 65 | self.lstm = nn.LSTM(output_dim, hidden_dim, num_layers = num_layers, 66 | dropout = dropout_rate, batch_first = True) 67 | 68 | self.out = nn.Linear(hidden_dim, output_dim) 69 | 70 | def forward(self, x, hidden): 71 | output, hidden = self.lstm(x, hidden) 72 | prediction = self.out(output.float()) 73 | return prediction, hidden 74 | 75 | class Seq2Seq(nn.Module): 76 | def __init__(self, input_dim, output_dim, hidden_dim, num_layers, quantile = False): 77 | super(Seq2Seq, self).__init__() 78 | self.encoder = Encoder(input_dim = input_dim, hidden_dim = hidden_dim, num_layers = num_layers).to(device) 79 | self.decoder = Decoder(output_dim = output_dim, hidden_dim = hidden_dim, num_layers = num_layers).to(device) 80 | self.output_dim = output_dim 81 | self.quantile = quantile 82 | if self.quantile: 83 | self.quantile = nn.Linear(1, 3) 84 | 85 | def forward(self, source, target_length): 86 | batch_size = source.size(0) 87 | input_length = source.size(1) 88 | output_dim = self.decoder.output_dim 89 | encoder_output, encoder_hidden = self.encoder(source) 90 | 91 | decoder_output = torch.zeros((batch_size, 1, output_dim), device = device) 92 | decoder_hidden = encoder_hidden 93 | 94 | outputs = [] 95 | for t in range(target_length): 96 | decoder_output, decoder_hidden = self.decoder(decoder_output, decoder_hidden) 97 | outputs.append(decoder_output) 98 | out = torch.cat(outputs, dim = 1) 99 | if self.quantile: 100 | out = self.quantile(out.unsqueeze(-1)) 101 | return out 102 | 103 | 104 | ######## Transformer ######## 105 | #Tranformer Encoder Only 106 | class Transformer_EncoderOnly(nn.Module): 107 | def __init__(self, input_dim, output_dim, nhead = 4, d_model = 128, num_layers = 6, dim_feedforward = 256): 108 | super(Transformer_EncoderOnly, self).__init__() 109 | encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout = 0)#.to(device) 110 | decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout = 0)#.to(device) 111 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = num_layers)#.to(device) 112 | #self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers = num_layers)#.to(device) 113 | self.embedding = nn.Linear(input_dim, d_model) 114 | self.output_layer = nn.Linear(d_model, output_dim) 115 | self.output_dim = output_dim 116 | 117 | def _generate_square_subsequent_mask(self, sz): 118 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 119 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 120 | #mask = mask.float().masked_fill(mask == 0, float('-inf'))# 121 | return mask 122 | 123 | def forward(self, xx, output_length, yy = None): 124 | src = self.embedding(xx).transpose(0,1) 125 | src_mask = self._generate_square_subsequent_mask(src.shape[0]).to(device) 126 | encoder_output = self.transformer_encoder(src, mask = src_mask)# 127 | outputs = [] 128 | if output_length == xx.shape[1]: 129 | out = self.output_layer(encoder_output).transpose(0,1) 130 | return out 131 | for i in range(output_length): 132 | out = self.output_layer(encoder_output).transpose(0,1)[:,-1:] 133 | xx = torch.cat([xx[:,1:], out], dim = 1) 134 | outputs.append(out) 135 | return torch.cat(outputs, dim = 1) 136 | 137 | 138 | #Tranformer Encoder-Decoder 139 | class Transformer(nn.Module): 140 | def __init__(self, input_dim, output_dim, nhead = 4, d_model = 128, num_layers = 6, dim_feedforward = 256, quantile = False): 141 | super(Transformer, self).__init__() 142 | encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout = 0) 143 | decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout = 0) 144 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers = num_layers) 145 | self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers = num_layers) 146 | self.embedding = nn.Linear(input_dim, d_model) 147 | self.output_layer = nn.Linear(d_model, output_dim) 148 | self.output_dim = output_dim 149 | self.quantile = quantile 150 | if self.quantile: 151 | self.quantile = nn.Linear(1, 3) 152 | 153 | def _generate_square_subsequent_mask(self, sz): 154 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 155 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 156 | return mask 157 | 158 | # teacher forcing: feed true yy during training 159 | # yy is None during inference 160 | def forward(self, xx, output_length, yy = None): 161 | src = self.embedding(xx) 162 | src_mask = self._generate_square_subsequent_mask(src.shape[0]).to(device) 163 | encoder_output = self.transformer_encoder(src, mask = src_mask) 164 | decoder_output = [] 165 | # greedy decode 166 | if yy is None: 167 | tgt = xx[:,-1:].transpose(0,1) 168 | for i in range(output_length): 169 | tgt_mask = self._generate_square_subsequent_mask(tgt.shape[0]).to(device) 170 | out = self.transformer_decoder(self.embedding(tgt), encoder_output, tgt_mask = tgt_mask) 171 | tgt = torch.cat([tgt, self.output_layer(out[-1:])], dim = 0) 172 | out = tgt[1:].transpose(0,1) 173 | else: 174 | tgt_beg = xx[:,-1:].transpose(0,1) 175 | tgt = self.embedding(torch.cat([tgt_beg, yy.transpose(0,1)[:-1]], dim = 0)) 176 | tgt_mask = self._generate_square_subsequent_mask(tgt.shape[0]).to(device) 177 | out = self.transformer_decoder(tgt, encoder_output, tgt_mask = tgt_mask) 178 | out = self.output_layer(out).transpose(0,1) 179 | 180 | if self.quantile: 181 | out = self.quantile(out.unsqueeze(-1)) 182 | return out 183 | 184 | ########### Latent ODE ######### 185 | class Latent_ODE(nn.Module): 186 | def __init__(self, latent_dim=4, obs_dim=2, nhidden=20, rhidden = 20, aug = False, aug_dim = 2, quantile = False): 187 | super(Latent_ODE, self).__init__() 188 | self.aug = aug 189 | self.aug_dim = aug_dim 190 | if self.aug: 191 | self.rec = RecognitionRNN(latent_dim, obs_dim+aug_dim, rhidden) 192 | else: 193 | self.rec = RecognitionRNN(latent_dim, obs_dim, rhidden) 194 | 195 | self.func = LatentODEfunc(latent_dim, nhidden) 196 | self.dec = LatentODEDecoder(latent_dim, obs_dim, nhidden) 197 | self.quantile = quantile 198 | if self.quantile: 199 | self.quantile = nn.Linear(1, 3) 200 | 201 | def forward(self, xx, output_length): 202 | time_steps = torch.arange(0, output_length, 0.01).float().to(device)[:output_length]#torch.linspace(0, 59, 60).float().to(device)[:output_length] 203 | if self.aug: 204 | aug_ten = torch.zeros(xx.shape[0], xx.shape[1], self.aug_dim).float().to(device) 205 | xx = torch.cat([xx, aug_ten], dim = -1) 206 | z0 = self.rec.forward(torch.flip(xx, [1])) 207 | pred_z = odeint(self.func, z0, time_steps).permute(1, 0, 2) 208 | out = self.dec(pred_z) 209 | if self.quantile: 210 | out = self.quantile(out.unsqueeze(-1)) 211 | return out 212 | 213 | class LatentODEfunc(nn.Module): 214 | def __init__(self, latent_dim=4, nhidden=20): 215 | super(LatentODEfunc, self).__init__() 216 | self.model = nn.Sequential( 217 | nn.Linear(latent_dim, nhidden), 218 | nn.ELU(), 219 | nn.Linear(nhidden, nhidden), 220 | nn.ELU(), 221 | nn.Linear(nhidden, nhidden), 222 | nn.ELU(), 223 | nn.Linear(nhidden, nhidden), 224 | nn.ELU(), 225 | nn.Linear(nhidden, latent_dim) 226 | ) 227 | self.nfe = 0 228 | 229 | def forward(self, t, x): 230 | self.nfe += 1 231 | out = self.model(x) 232 | return out 233 | 234 | class RecognitionRNN(nn.Module): 235 | def __init__(self, latent_dim=4, obs_dim=2, nhidden=25): 236 | super(RecognitionRNN, self).__init__() 237 | self.nhidden = nhidden 238 | self.model = nn.GRU(obs_dim, nhidden, batch_first = True) 239 | self.linear = nn.Linear(nhidden, latent_dim) 240 | 241 | def forward(self, x): 242 | #h0 = torch.zeros(1, x.shape[0], self.nhidden).to(device) 243 | output, hn = self.model(x)#, h0 244 | return self.linear(hn[0]) 245 | 246 | class LatentODEDecoder(nn.Module): 247 | def __init__(self, latent_dim=4, obs_dim=2, nhidden=20): 248 | super(LatentODEDecoder, self).__init__() 249 | self.model = nn.Sequential( 250 | nn.Linear(latent_dim, nhidden), 251 | nn.ReLU(), 252 | nn.Linear(nhidden, obs_dim) 253 | ) 254 | 255 | def forward(self, z): 256 | out = self.model(z) 257 | return out 258 | 259 | 260 | 261 | ##### Neural Encoder + ODE Solvers ######### 262 | class Neural_ODE(nn.Module): 263 | def __init__(self, input_dim, input_length, hidden_dim, solver = "Euler", encoder = "fc"): 264 | super(Neural_ODE, self).__init__() 265 | 266 | if encoder == "fc": 267 | self.encode_fc = nn.Sequential( 268 | nn.Linear(input_length*input_dim, hidden_dim), 269 | nn.LeakyReLU(), 270 | nn.Linear(hidden_dim, hidden_dim), 271 | nn.LeakyReLU(), 272 | nn.Linear(hidden_dim, hidden_dim), 273 | nn.LeakyReLU(), 274 | nn.Linear(hidden_dim, hidden_dim), 275 | nn.LeakyReLU(), 276 | nn.Linear(hidden_dim, hidden_dim), 277 | nn.LeakyReLU(), 278 | nn.Linear(hidden_dim, hidden_dim), 279 | nn.LeakyReLU(), 280 | nn.Linear(hidden_dim, 3) 281 | ) 282 | else: 283 | self.encode_lstm_1 = nn.LSTM(input_dim, hidden_dim, num_layers = 3, 284 | bidirectional = True, batch_first = True) 285 | self.encode_lstm_2 = nn.Linear(hidden_dim*2, 3) 286 | 287 | self.step = torch.tensor(1)#0.5 288 | self.encoder = encoder 289 | self.solver = solver 290 | 291 | 292 | def f_S(self, S_n, I_n, E_n, beta): 293 | return -beta*S_n*I_n 294 | 295 | def f_E(self, S_n, I_n, E_n, beta, sigma): 296 | return beta*S_n*I_n - sigma*E_n 297 | 298 | def f_I(self, I_n, E_n, sigma, gamma): 299 | return sigma*E_n - gamma*I_n 300 | #self.mu * self.sigma * E_n - self.gamma*I_n 301 | 302 | def f_R(self, I_n, gamma): 303 | return gamma*I_n 304 | 305 | # dt is included in the ks 306 | def RK4_update(self, f_n, k1, k2, k3, k4): 307 | return f_n + 1/6 * (k1 + 2 * k2 + 2 * k3 + k4) * self.step 308 | 309 | def RK4(self, t, initial, beta, gamma, sigma = None): 310 | 311 | S_pred = [initial[:,0:1]] 312 | E_pred = [initial[:,1:2]] 313 | I_pred = [initial[:,2:3]] 314 | R_pred = [initial[:,3:4]] 315 | for n in range(len(t)-1): 316 | # dt * f(t[n], y[n]) 317 | k1_S = self.f_S(S_pred[n], I_pred[n], E_pred[n], beta) 318 | k1_E = self.f_E(S_pred[n], I_pred[n], E_pred[n], beta, sigma) 319 | k1_I = self.f_I(I_pred[n], E_pred[n], sigma, gamma) 320 | k1_R = self.f_R(I_pred[n], gamma) 321 | 322 | # dt * f(t[n] + dt/2, y[n] + k1/2) 323 | S_plus_k1_half = S_pred[n] + k1_S / 2 * self.step 324 | I_plus_k1_half = I_pred[n] + k1_I / 2 * self.step 325 | E_plus_k1_half = E_pred[n] + k1_E / 2 * self.step 326 | 327 | k2_S = self.f_S(S_plus_k1_half, I_plus_k1_half, E_plus_k1_half, beta) 328 | k2_E = self.f_E(S_plus_k1_half, I_plus_k1_half, E_plus_k1_half, beta, sigma) 329 | k2_I = self.f_I(I_plus_k1_half, E_plus_k1_half, sigma, gamma) 330 | k2_R = self.f_R(I_plus_k1_half, gamma) 331 | 332 | # dt * f(t[n] + dt/2, y[n] + k2/2) 333 | S_plus_k2_half = S_pred[n] + k2_S / 2 * self.step 334 | I_plus_k2_half = I_pred[n] + k2_I / 2 * self.step 335 | E_plus_k2_half = E_pred[n] + k2_E / 2 * self.step 336 | 337 | k3_S = self.f_S(S_plus_k2_half, I_plus_k2_half, E_plus_k2_half, beta) 338 | k3_E = self.f_E(S_plus_k2_half, I_plus_k2_half, E_plus_k2_half, beta, sigma) 339 | k3_I = self.f_I(I_plus_k2_half, E_plus_k2_half, sigma, gamma) 340 | k3_R = self.f_R(I_plus_k2_half, gamma) 341 | 342 | # dt * f(t[n] + dt, y[n] + k3) 343 | S_plus_k3 = S_pred[n] + k3_S * self.step 344 | I_plus_k3 = I_pred[n] + k3_I * self.step 345 | E_plus_k3 = E_pred[n] + k3_E * self.step 346 | 347 | k4_S = self.f_S(S_plus_k3, I_plus_k3, E_plus_k3, beta) 348 | k4_E = self.f_E(S_plus_k3, I_plus_k3, E_plus_k3, beta, sigma) 349 | k4_I = self.f_I(I_plus_k3, E_plus_k3, sigma, gamma) 350 | k4_R = self.f_R(I_plus_k3, gamma) 351 | 352 | 353 | S_pred.append(self.RK4_update(S_pred[n], k1_S, k2_S, k3_S, k4_S)) 354 | E_pred.append(self.RK4_update(E_pred[n], k1_E, k2_E, k3_E, k4_E)) 355 | I_pred.append(self.RK4_update(I_pred[n], k1_I, k2_I, k3_I, k4_I)) 356 | R_pred.append(self.RK4_update(R_pred[n], k1_R, k2_R, k3_R, k4_R)) 357 | 358 | y_pred = torch.cat([torch.cat(S_pred, dim = 1).unsqueeze(-1), 359 | torch.cat(E_pred, dim = 1).unsqueeze(-1), 360 | torch.cat(I_pred, dim = 1).unsqueeze(-1), 361 | torch.cat(R_pred, dim = 1).unsqueeze(-1)], dim = 2) 362 | return y_pred 363 | 364 | 365 | 366 | def forward(self, xx, output_length): 367 | if self.encoder == "fc": 368 | out = self.encode_fc(xx.reshape(xx.shape[0], -1)) 369 | elif self.encoder == "lstm": 370 | out = self.encode_lstm_2(self.encode_lstm_1(xx)[0][:,-1]) 371 | else: 372 | return "Error" 373 | 374 | 375 | t = torch.linspace(0, 60//2, 61).float().cuda()[:output_length] 376 | 377 | if self.solver == "Euler": 378 | return self.Euler(t, xx[:,0], out[:,0:1], out[:,1:2], out[:,2:3]) 379 | elif self.solver == "RK4": 380 | return self.RK4(t, xx[:,0], out[:,0:1], out[:,1:2], out[:,2:3]) 381 | else: 382 | return "Error" -------------------------------------------------------------------------------- /ode_nn/Graph.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dgl.nn import GraphConv, GatedGraphConv, GATConv 3 | import dgl 4 | import torch 5 | import torch.nn as nn 6 | 7 | class mySequential(nn.Sequential): 8 | def forward(self, *input): 9 | for module in self._modules.values(): 10 | input = module(*input) 11 | return input 12 | 13 | class GCN_Layer(nn.Module): 14 | def __init__(self, in_dim, out_dim, is_final_layer = False): 15 | super(GCN_Layer, self).__init__() 16 | self.gcn_conv = GraphConv(in_dim, out_dim, norm='both')# 17 | self.activation = nn.LeakyReLU() 18 | self.is_final_layer = is_final_layer 19 | 20 | def forward(self, g, xx): 21 | g = dgl.transform.remove_self_loop(g) 22 | g = dgl.transform.add_self_loop(g) 23 | if self.is_final_layer: 24 | return self.gcn_conv(g, xx) 25 | else: 26 | out = self.gcn_conv(g, xx) 27 | out = self.activation(out) 28 | return g, out 29 | 30 | class GAT_Layer(nn.Module): 31 | def __init__(self, in_dim, out_dim, num_heads = 4, is_final_layer = False): 32 | super(GAT_Layer, self).__init__() 33 | self.gat = GATConv(in_dim, out_dim, num_heads=num_heads) 34 | self.batchnorm = nn.BatchNorm1d(out_dim * num_heads) 35 | self.activation = nn.LeakyReLU() 36 | self.is_final_layer = is_final_layer 37 | 38 | def forward(self, g, xx): 39 | if self.is_final_layer: 40 | 41 | return torch.stack([self.gat(g, xx[i]) for i in range(xx.shape[0])]) 42 | else: 43 | out = torch.stack([self.gat(g, xx[i]) for i in range(xx.shape[0])]) 44 | out = self.activation(self.batchnorm(out.flatten(2).transpose(1,2)).transpose(1,2)) 45 | return g, out 46 | 47 | class GAT(nn.Module): 48 | def __init__(self, in_dim, out_dim, hidden_dim, num_heads, num_layer = 5, quantile = False): 49 | super(GAT, self).__init__() 50 | self.model = [GAT_Layer(in_dim, hidden_dim)] 51 | self.model += [GAT_Layer(hidden_dim*num_heads, hidden_dim) for i in range(num_layer-2)] 52 | self.output_layer = nn.Linear(hidden_dim*num_heads, out_dim) 53 | self.model = mySequential(*self.model) 54 | self.quantile = quantile 55 | if self.quantile: 56 | self.quantile = nn.Linear(1, 3) 57 | 58 | def forward(self, g, xx, output_length): 59 | outputs = [] 60 | for i in range(output_length): 61 | out = self.output_layer(self.model(g, xx)[1]) 62 | xx = torch.cat([xx[:,:,3:], out], dim = -1) 63 | outputs.append(out.unsqueeze(2)) 64 | out = torch.cat(outputs, dim = 2) 65 | #print(out.unsqueeze(-1).shape) 66 | if self.quantile: 67 | out = self.quantile(out.unsqueeze(-1)) 68 | return out 69 | 70 | 71 | 72 | class GCN(nn.Module): 73 | def __init__(self, in_dim, out_dim, hidden_dim, num_layer = 5, quantile = False): 74 | super(GCN, self).__init__() 75 | self.model = [GCN_Layer(in_dim, hidden_dim)] 76 | self.model += [GCN_Layer(hidden_dim, hidden_dim) for i in range(num_layer-2)] 77 | self.model += [GCN_Layer(hidden_dim, out_dim, is_final_layer = True)] 78 | self.model = mySequential(*self.model) 79 | self.quantile = quantile 80 | if self.quantile: 81 | self.quantile = nn.Linear(1, 3) 82 | 83 | def forward(self, g, xx, output_length): 84 | xx = xx.transpose(0,1) 85 | outputs = [] 86 | for i in range(output_length): 87 | out = self.model(g, xx) 88 | xx = torch.cat([xx[:,:,3:], out], dim = -1) 89 | outputs.append(out.unsqueeze(2)) 90 | out = torch.cat(outputs, dim = 2).transpose(0,1) 91 | 92 | if self.quantile: 93 | out = self.quantile(out.unsqueeze(-1)) 94 | return out -------------------------------------------------------------------------------- /ode_nn/__init__.py: -------------------------------------------------------------------------------- 1 | #from .neural_odes import * 2 | from .DNN import * 3 | from .train import * 4 | from .Graph import * 5 | from .torchdiffeq import odeint, odeint_adjoint 6 | from .AutoODE import * 7 | -------------------------------------------------------------------------------- /ode_nn/__pycache__/AdjMask_SuEIR.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/AdjMask_SuEIR.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/AdjMask_SuEIR.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/AdjMask_SuEIR.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/AutoODE.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/AutoODE.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/DNN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/DNN.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/DNN.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/DNN.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/Graph.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/Graph.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/Graph.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/Graph.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/LSTM.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/LSTM.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/graph_models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/graph_models.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/neural_odes.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/neural_odes.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/seq_models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/seq_models.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/__pycache__/train.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/__pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/mobility/Mobility.py: -------------------------------------------------------------------------------- 1 | stayhome = {"AK": 29.3, 2 | "AL": 23.8, 3 | "AR": 24.2, 4 | "AZ": 34.2, 5 | "CA": 35.6, 6 | "CO": 30.9, 7 | "CT": 32.6, 8 | "DC": 40.1, 9 | "DE": 31.9, 10 | "FL": 31.6, 11 | "GA": 27.8, 12 | "HI": 30.3, 13 | "IA": 25.7, 14 | "ID": 29.1, 15 | "IL": 30.6, 16 | "IN": 27.4, 17 | "KS": 26.4, 18 | "KY": 26.3, 19 | "LA": 25.2, 20 | "MA": 34.7, 21 | "MD": 34.6, 22 | "ME": 30.5, 23 | "MI": 28.3, 24 | "MN": 30.3, 25 | "MO": 26.5, 26 | "MS": 23.0, 27 | "MT": 28.8, 28 | "NC": 28.5, 29 | "ND": 26.4, 30 | "NE": 26.1, 31 | "NH": 31.4, 32 | "NJ": 33.3, 33 | "NM": 31.7, 34 | "NV": 33.8, 35 | "NY": 35.4, 36 | "OH": 27.7, 37 | "OK": 23.8, 38 | "OR": 33.1, 39 | "PA": 30.8, 40 | "RI": 32.7, 41 | "SC": 26.6, 42 | "SD": 26.1, 43 | "TN": 26.5, 44 | "TX": 31.0, 45 | "UT": 30.2, 46 | "VA": 31.7, 47 | "VT": 32.2, 48 | "WA": 34.2, 49 | "WI": 28.8, 50 | "WV": 26.7, 51 | "WY": 28.6} 52 | 53 | us_abbr = np.array(["NY", "NJ", "MA", "MI", "PA", 54 | "CA", "IL", "FL", "LA", "TX", 55 | "CT", "GA", "WA", "MD", "ID", 56 | "CO", "OH", "VA", "TN", "NC", 57 | "MO", "AL", "AZ", "WI", "SC", 58 | "NV", "MS", "RI", "UT", "OK", "KY", 59 | "DC", "DE", "IA", "MN", "OR", 60 | "ID", "AR", "KS", "NM", "NH", 61 | "PR", "SD", "NE", "VT", "ME", 62 | "WV", "HI", "MT", "ND", "AK", 63 | "WY", "GU", "VI", "MP", "AS"]) 64 | us_beta = [] 65 | for state in us_abbr: 66 | try: 67 | us_beta.append(1-stayhome[state]/100) 68 | except: 69 | us_beta.append(1-0.3) 70 | torch.save(torch.from_numpy(np.array(us_beta)).float(), "us_beta.pt") -------------------------------------------------------------------------------- /ode_nn/mobility/us_beta.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/mobility/us_beta.pt -------------------------------------------------------------------------------- /ode_nn/mobility/us_graph.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/mobility/us_graph.pt -------------------------------------------------------------------------------- /ode_nn/population_states.csv: -------------------------------------------------------------------------------- 1 | State,2018 Population 2 | California,39.77683 3 | Texas,28.70433 4 | Florida,21.312211 5 | New York,19.862512 6 | Pennsylvania,12.823989 7 | Illinois,12.76832 8 | Ohio,11.694664 9 | Georgia,10.545138 10 | North Carolina,10.390149 11 | Michigan,9.991177 12 | New Jersey,9.032872 13 | Virginia,8.52566 14 | Washington,7.530552 15 | Arizona,7.123898 16 | Massachusetts,6.895917 17 | Tennessee,6.782564 18 | Indiana,6.699629 19 | Missouri,6.135888 20 | Maryland,6.079602 21 | Wisconsin,5.818049 22 | Colorado,5.684203 23 | Minnesota,5.628162 24 | South Carolina,5.088916 25 | Alabama,4.888949 26 | Louisiana,4.682509 27 | Kentucky,4.472265 28 | Oregon,4.199563 29 | Oklahoma,3.940521 30 | Connecticut,3.588683 31 | Iowa,3.160553 32 | Utah,3.159345 33 | Nevada,3.056824 34 | Arkansas,3.020327 35 | Mississippi,2.982785 36 | Kansas,2.918515 37 | New Mexico,2.090708 38 | Nebraska,1.932549 39 | West Virginia,1.803077 40 | Idaho,1.75386 41 | Hawaii,1.426393 42 | New Hampshire,1.350575 43 | Maine,1.341582 44 | Montana,1.06233 45 | Rhode Island,1.061712 46 | Delaware,0.97118 47 | South Dakota,0.87779 48 | North Dakota,0.755238 49 | Alaska,0.738068 50 | District of Columbia,0.703608 51 | Vermont,0.62396 52 | Wyoming,0.57372 53 | American Samoa,0.055312 54 | Guam,0.167294 55 | Northern Mariana Islands,0.057215999999999996 56 | Puerto Rico,3.194 57 | Virgin Islands,0.10663099999999999 58 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/__init__.py: -------------------------------------------------------------------------------- 1 | from ._impl import odeint 2 | from ._impl import odeint_adjoint 3 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__init__.py: -------------------------------------------------------------------------------- 1 | from .odeint import odeint 2 | from .adjoint import odeint_adjoint 3 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/adams.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/adams.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/adams.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/adams.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/adaptive_heun.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/adaptive_heun.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/adaptive_heun.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/adaptive_heun.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/adjoint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/adjoint.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/adjoint.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/adjoint.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/bosh3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/bosh3.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/bosh3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/bosh3.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/dopri5.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/dopri5.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/dopri5.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/dopri5.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/dopri8.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/dopri8.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/dopri8.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/dopri8.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/dopri8_coefficients.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/dopri8_coefficients.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/dopri8_coefficients.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/dopri8_coefficients.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/fixed_adams.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/fixed_adams.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/fixed_adams.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/fixed_adams.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/fixed_grid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/fixed_grid.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/fixed_grid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/fixed_grid.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/interp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/interp.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/interp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/interp.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/misc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/misc.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/odeint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/odeint.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/odeint.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/odeint.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/rk_common.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/rk_common.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/rk_common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/rk_common.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/solvers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/solvers.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/solvers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/solvers.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/tsit5.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/tsit5.cpython-36.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/__pycache__/tsit5.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rose-STL-Lab/AutoODE-DSL/a64fa003fa252daaab6a4731df14b9f5bb007dc0/ode_nn/torchdiffeq/_impl/__pycache__/tsit5.cpython-37.pyc -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/adams.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | from .solvers import AdaptiveStepsizeODESolver 4 | from .misc import ( 5 | _handle_unused_kwargs, _select_initial_step, _convert_to_tensor, _scaled_dot_product, _is_iterable, 6 | _optimal_step_size, _compute_error_ratio 7 | ) 8 | 9 | _MIN_ORDER = 1 10 | _MAX_ORDER = 12 11 | 12 | gamma_star = [ 13 | 1, -1 / 2, -1 / 12, -1 / 24, -19 / 720, -3 / 160, -863 / 60480, -275 / 24192, -33953 / 3628800, -0.00789255, 14 | -0.00678585, -0.00592406, -0.00523669, -0.0046775, -0.00421495, -0.0038269 15 | ] 16 | 17 | 18 | class _VCABMState(collections.namedtuple('_VCABMState', 'y_n, prev_f, prev_t, next_t, phi, order')): 19 | """Saved state of the variable step size Adams-Bashforth-Moulton solver as described in 20 | 21 | Solving Ordinary Differential Equations I - Nonstiff Problems III.5 22 | by Ernst Hairer, Gerhard Wanner, and Syvert P Norsett. 23 | """ 24 | 25 | 26 | def g_and_explicit_phi(prev_t, next_t, implicit_phi, k): 27 | curr_t = prev_t[0] 28 | dt = next_t - prev_t[0] 29 | 30 | g = torch.empty(k + 1).to(prev_t[0]) 31 | explicit_phi = collections.deque(maxlen=k) 32 | beta = torch.tensor(1).to(prev_t[0]) 33 | 34 | g[0] = 1 35 | c = 1 / torch.arange(1, k + 2).to(prev_t[0]) 36 | explicit_phi.append(implicit_phi[0]) 37 | 38 | for j in range(1, k): 39 | beta = (next_t - prev_t[j - 1]) / (curr_t - prev_t[j]) * beta 40 | beat_cast = beta.to(implicit_phi[j][0]) 41 | explicit_phi.append(tuple(iphi_ * beat_cast for iphi_ in implicit_phi[j])) 42 | 43 | c = c[:-1] - c[1:] if j == 1 else c[:-1] - c[1:] * dt / (next_t - prev_t[j - 1]) 44 | g[j] = c[0] 45 | 46 | c = c[:-1] - c[1:] * dt / (next_t - prev_t[k - 1]) 47 | g[k] = c[0] 48 | 49 | return g, explicit_phi 50 | 51 | 52 | def compute_implicit_phi(explicit_phi, f_n, k): 53 | k = min(len(explicit_phi) + 1, k) 54 | implicit_phi = collections.deque(maxlen=k) 55 | implicit_phi.append(f_n) 56 | for j in range(1, k): 57 | implicit_phi.append(tuple(iphi_ - ephi_ for iphi_, ephi_ in zip(implicit_phi[j - 1], explicit_phi[j - 1]))) 58 | return implicit_phi 59 | 60 | 61 | class VariableCoefficientAdamsBashforth(AdaptiveStepsizeODESolver): 62 | 63 | def __init__( 64 | self, func, y0, rtol, atol, implicit=True, first_step=None, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0, dfactor=0.2, 65 | **unused_kwargs 66 | ): 67 | _handle_unused_kwargs(self, unused_kwargs) 68 | del unused_kwargs 69 | 70 | self.func = func 71 | self.y0 = y0 72 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 73 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0) 74 | self.implicit = implicit 75 | self.first_step = first_step 76 | self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER))) 77 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 78 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 79 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 80 | 81 | def before_integrate(self, t): 82 | prev_f = collections.deque(maxlen=self.max_order + 1) 83 | prev_t = collections.deque(maxlen=self.max_order + 1) 84 | phi = collections.deque(maxlen=self.max_order) 85 | 86 | t0 = t[0] 87 | f0 = self.func(t0.type_as(self.y0[0]), self.y0) 88 | prev_t.appendleft(t0) 89 | prev_f.appendleft(f0) 90 | phi.appendleft(f0) 91 | if self.first_step is None: 92 | first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t) 93 | else: 94 | first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t) 95 | 96 | self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1) 97 | 98 | def advance(self, final_t): 99 | final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0]) 100 | while final_t > self.vcabm_state.prev_t[0]: 101 | self.vcabm_state = self._adaptive_adams_step(self.vcabm_state, final_t) 102 | assert final_t == self.vcabm_state.prev_t[0] 103 | return self.vcabm_state.y_n 104 | 105 | def _adaptive_adams_step(self, vcabm_state, final_t): 106 | y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state 107 | if next_t > final_t: 108 | next_t = final_t 109 | dt = (next_t - prev_t[0]) 110 | dt_cast = dt.to(y0[0]) 111 | 112 | # Explicit predictor step. 113 | g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order) 114 | g = g.to(y0[0]) 115 | p_next = tuple( 116 | y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)]) 117 | for y0_, phi_ in zip(y0, tuple(zip(*phi))) 118 | ) 119 | 120 | # Update phi to implicit. 121 | next_f0 = self.func(next_t.to(p_next[0]), p_next) 122 | implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1) 123 | 124 | # Implicit corrector step. 125 | y_next = tuple( 126 | p_next_ + dt_cast * g[order - 1] * iphi_ for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1]) 127 | ) 128 | 129 | # Error estimation. 130 | tolerance = tuple( 131 | atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_)) 132 | for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next) 133 | ) 134 | local_error = tuple(dt_cast * (g[order] - g[order - 1]) * iphi_ for iphi_ in implicit_phi_p[order]) 135 | error_k = _compute_error_ratio(local_error, tolerance) 136 | accept_step = (torch.tensor(error_k) <= 1).all() 137 | 138 | if not accept_step: 139 | # Retry with adjusted step size if step is rejected. 140 | dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order) 141 | return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order) 142 | 143 | # We accept the step. Evaluate f and update phi. 144 | next_f0 = self.func(next_t.to(p_next[0]), y_next) 145 | implicit_phi = compute_implicit_phi(phi, next_f0, order + 2) 146 | 147 | next_order = order 148 | 149 | if len(prev_t) <= 4 or order < 3: 150 | next_order = min(order + 1, 3, self.max_order) 151 | else: 152 | error_km1 = _compute_error_ratio( 153 | tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_ for iphi_ in implicit_phi_p[order - 1]), tolerance 154 | ) 155 | error_km2 = _compute_error_ratio( 156 | tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_ for iphi_ in implicit_phi_p[order - 2]), tolerance 157 | ) 158 | if min(error_km1 + error_km2) < max(error_k): 159 | next_order = order - 1 160 | elif order < self.max_order: 161 | error_kp1 = _compute_error_ratio( 162 | tuple(dt_cast * gamma_star[order] * iphi_ for iphi_ in implicit_phi_p[order]), tolerance 163 | ) 164 | if max(error_kp1) < max(error_k): 165 | next_order = order + 1 166 | 167 | # Keep step size constant if increasing order. Else use adaptive step size. 168 | dt_next = dt if next_order > order else _optimal_step_size( 169 | dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1 170 | ) 171 | 172 | prev_f.appendleft(next_f0) 173 | prev_t.appendleft(next_t) 174 | return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order) 175 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/adaptive_heun.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate 2 | import torch 3 | from .misc import ( 4 | _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable, 5 | _optimal_step_size, _compute_error_ratio 6 | ) 7 | from .solvers import AdaptiveStepsizeODESolver 8 | from .interp import _interp_fit, _interp_evaluate 9 | from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step 10 | 11 | _ADAPTIVE_HEUN_TABLEAU = _ButcherTableau( 12 | alpha=[1.], 13 | beta=[ 14 | [1.], 15 | ], 16 | c_sol=[0.5, 0.5], 17 | c_error=[ 18 | 0.5, 19 | -0.5, 20 | ], 21 | ) 22 | 23 | AH_C_MID = [ 24 | 0.5, 0. 25 | ] 26 | 27 | 28 | def _interp_fit_adaptive_heun(y0, y1, k, dt, tableau=_ADAPTIVE_HEUN_TABLEAU): 29 | """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" 30 | dt = dt.type_as(y0[0]) 31 | y_mid = tuple(y0_ + _scaled_dot_product(dt, AH_C_MID, k_) for y0_, k_ in zip(y0, k)) 32 | f0 = tuple(k_[0] for k_ in k) 33 | f1 = tuple(k_[-1] for k_ in k) 34 | return _interp_fit(y0, y1, y_mid, f0, f1, dt) 35 | 36 | 37 | def _abs_square(x): 38 | return torch.mul(x, x) 39 | 40 | 41 | def _ta_append(list_of_tensors, value): 42 | """Append a value to the end of a list of PyTorch tensors.""" 43 | list_of_tensors.append(value) 44 | return list_of_tensors 45 | 46 | 47 | class AdaptiveHeunSolver(AdaptiveStepsizeODESolver): 48 | 49 | def __init__( 50 | self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, 51 | **unused_kwargs 52 | ): 53 | _handle_unused_kwargs(self, unused_kwargs) 54 | del unused_kwargs 55 | 56 | self.func = func 57 | self.y0 = y0 58 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 59 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0) 60 | self.first_step = first_step 61 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 62 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 63 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 64 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) 65 | 66 | def before_integrate(self, t): 67 | f0 = self.func(t[0].type_as(self.y0[0]), self.y0) 68 | if self.first_step is None: 69 | first_step = _select_initial_step(self.func, t[0], self.y0, 1, self.rtol[0], self.atol[0], f0=f0).to(t) 70 | else: 71 | first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device) 72 | self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5) 73 | 74 | def advance(self, next_t): 75 | """Interpolate through the next time point, integrating as necessary.""" 76 | n_steps = 0 77 | while next_t > self.rk_state.t1: 78 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) 79 | self.rk_state = self._adaptive_heun_step(self.rk_state) 80 | n_steps += 1 81 | return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) 82 | 83 | def _adaptive_heun_step(self, rk_state): 84 | """Take an adaptive Runge-Kutta step to integrate the ODE.""" 85 | y0, f0, _, t0, dt, interp_coeff = rk_state 86 | ######################################################## 87 | # Assertions # 88 | ######################################################## 89 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) 90 | for y0_ in y0: 91 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) 92 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_ADAPTIVE_HEUN_TABLEAU) 93 | 94 | ######################################################## 95 | # Error Ratio # 96 | ######################################################## 97 | mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) 98 | accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all() 99 | 100 | ######################################################## 101 | # Update RK State # 102 | ######################################################## 103 | y_next = y1 if accept_step else y0 104 | f_next = f1 if accept_step else f0 105 | t_next = t0 + dt if accept_step else t0 106 | interp_coeff = _interp_fit_adaptive_heun(y0, y1, k, dt) if accept_step else interp_coeff 107 | dt_next = _optimal_step_size( 108 | dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5 109 | ) 110 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) 111 | return rk_state 112 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/adjoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from . import odeint 4 | from .misc import _flatten, _flatten_convert_none_to_zeros 5 | 6 | 7 | class OdeintAdjointMethod(torch.autograd.Function): 8 | 9 | @staticmethod 10 | def forward(ctx, *args): 11 | assert len(args) >= 8, 'Internal error: all arguments required.' 12 | (y0, func, t, flat_params, rtol, atol, method, options, adjoint_rtol, adjoint_atol, adjoint_method, 13 | adjoint_options) = (args[:-11], args[-11], args[-10], args[-9], args[-8], args[-7], args[-6], args[-5], 14 | args[-4], args[-3], args[-2], args[-1]) 15 | 16 | (ctx.func, ctx.adjoint_rtol, ctx.adjoint_atol, ctx.adjoint_method, 17 | ctx.adjoint_options) = func, adjoint_rtol, adjoint_atol, adjoint_method, adjoint_options 18 | 19 | with torch.no_grad(): 20 | ans = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options) 21 | ctx.save_for_backward(t, flat_params, *ans) 22 | return ans 23 | 24 | @staticmethod 25 | def backward(ctx, *grad_output): 26 | t, flat_params, *ans = ctx.saved_tensors 27 | ans = tuple(ans) 28 | (func, adjoint_rtol, adjoint_atol, adjoint_method, 29 | adjoint_options) = ctx.func, ctx.adjoint_rtol, ctx.adjoint_atol, ctx.adjoint_method, ctx.adjoint_options 30 | n_tensors = len(ans) 31 | f_params = tuple(func.parameters()) 32 | 33 | # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives. 34 | def augmented_dynamics(t, y_aug): 35 | # Dynamics of the original system augmented with 36 | # the adjoint wrt y, and an integrator wrt t and args. 37 | y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors] # Ignore adj_time and adj_params. 38 | 39 | with torch.set_grad_enabled(True): 40 | t = t.to(y[0].device).detach().requires_grad_(True) 41 | y = tuple(y_.detach().requires_grad_(True) for y_ in y) 42 | func_eval = func(t, y) 43 | vjp_t, *vjp_y_and_params = torch.autograd.grad( 44 | func_eval, (t,) + y + f_params, 45 | tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True 46 | ) 47 | vjp_y = vjp_y_and_params[:n_tensors] 48 | vjp_params = vjp_y_and_params[n_tensors:] 49 | 50 | # autograd.grad returns None if no gradient, set to zero. 51 | vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t 52 | vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y)) 53 | vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params) 54 | 55 | if len(f_params) == 0: 56 | vjp_params = torch.tensor(0.).to(vjp_y[0]) 57 | return (*func_eval, *vjp_y, vjp_t, vjp_params) 58 | 59 | T = ans[0].shape[0] 60 | with torch.no_grad(): 61 | adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output) 62 | adj_params = torch.zeros_like(flat_params) 63 | adj_time = torch.tensor(0.).to(t) 64 | time_vjps = [] 65 | for i in range(T - 1, 0, -1): 66 | 67 | ans_i = tuple(ans_[i] for ans_ in ans) 68 | grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output) 69 | func_i = func(t[i], ans_i) 70 | 71 | # Compute the effect of moving the current time measurement point. 72 | dLd_cur_t = sum( 73 | torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1) 74 | for func_i_, grad_output_i_ in zip(func_i, grad_output_i) 75 | ) 76 | adj_time = adj_time - dLd_cur_t 77 | time_vjps.append(dLd_cur_t) 78 | 79 | # Run the augmented system backwards in time. 80 | if adj_params.numel() == 0: 81 | adj_params = torch.tensor(0.).to(adj_y[0]) 82 | aug_y0 = (*ans_i, *adj_y, adj_time, adj_params) 83 | aug_ans = odeint( 84 | augmented_dynamics, aug_y0, 85 | torch.tensor([t[i], t[i - 1]]), 86 | rtol=adjoint_rtol, atol=adjoint_atol, method=adjoint_method, options=adjoint_options 87 | ) 88 | 89 | # Unpack aug_ans. 90 | adj_y = aug_ans[n_tensors:2 * n_tensors] 91 | adj_time = aug_ans[2 * n_tensors] 92 | adj_params = aug_ans[2 * n_tensors + 1] 93 | 94 | adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y) 95 | if len(adj_time) > 0: adj_time = adj_time[1] 96 | if len(adj_params) > 0: adj_params = adj_params[1] 97 | 98 | adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output)) 99 | 100 | del aug_y0, aug_ans 101 | 102 | time_vjps.append(adj_time) 103 | time_vjps = torch.cat(time_vjps[::-1]) 104 | 105 | return (*adj_y, None, time_vjps, adj_params, None, None, None, None, None, None, None, None) 106 | 107 | 108 | def odeint_adjoint(func, y0, t, rtol=1e-6, atol=1e-12, method=None, options=None, adjoint_rtol=None, adjoint_atol=None, 109 | adjoint_method=None, adjoint_options=None): 110 | 111 | # We need this in order to access the variables inside this module, 112 | # since we have no other way of getting variables along the execution path. 113 | if not isinstance(func, nn.Module): 114 | raise ValueError('func is required to be an instance of nn.Module.') 115 | 116 | if adjoint_rtol is None: 117 | adjoint_rtol = rtol 118 | if adjoint_atol is None: 119 | adjoint_atol = atol 120 | if adjoint_method is None: 121 | adjoint_method = method 122 | if adjoint_options is None: 123 | adjoint_options = options 124 | 125 | tensor_input = False 126 | if torch.is_tensor(y0): 127 | 128 | class TupleFunc(nn.Module): 129 | 130 | def __init__(self, base_func): 131 | super(TupleFunc, self).__init__() 132 | self.base_func = base_func 133 | 134 | def forward(self, t, y): 135 | return (self.base_func(t, y[0]),) 136 | 137 | tensor_input = True 138 | y0 = (y0,) 139 | func = TupleFunc(func) 140 | 141 | flat_params = _flatten(func.parameters()) 142 | ys = OdeintAdjointMethod.apply(*y0, func, t, flat_params, rtol, atol, method, options, adjoint_rtol, adjoint_atol, 143 | adjoint_method, adjoint_options) 144 | 145 | if tensor_input: 146 | ys = ys[0] 147 | return ys 148 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/bosh3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .misc import ( 3 | _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable, 4 | _optimal_step_size, _compute_error_ratio 5 | ) 6 | from .solvers import AdaptiveStepsizeODESolver 7 | from .interp import _interp_fit, _interp_evaluate 8 | from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step 9 | 10 | _BOGACKI_SHAMPINE_TABLEAU = _ButcherTableau( 11 | alpha=[1/2, 3/4, 1.], 12 | beta=[ 13 | [1/2], 14 | [0., 3/4], 15 | [2/9, 1/3, 4/9] 16 | ], 17 | c_sol=[2/9, 1/3, 4/9, 0.], 18 | c_error=[2/9-7/24, 1/3-1/4, 4/9-1/3, -1/8], 19 | ) 20 | 21 | BS_C_MID = [ 0., 0.5, 0., 0. ] 22 | 23 | 24 | def _interp_fit_bosh3(y0, y1, k, dt): 25 | """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" 26 | dt = dt.type_as(y0[0]) 27 | y_mid = tuple(y0_ + _scaled_dot_product(dt, BS_C_MID, k_) for y0_, k_ in zip(y0, k)) 28 | f0 = tuple(k_[0] for k_ in k) 29 | f1 = tuple(k_[-1] for k_ in k) 30 | return _interp_fit(y0, y1, y_mid, f0, f1, dt) 31 | 32 | 33 | 34 | class Bosh3Solver(AdaptiveStepsizeODESolver): 35 | 36 | def __init__( 37 | self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, 38 | **unused_kwargs 39 | ): 40 | _handle_unused_kwargs(self, unused_kwargs) 41 | del unused_kwargs 42 | 43 | self.func = func 44 | self.y0 = y0 45 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 46 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0) 47 | self.first_step = first_step 48 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 49 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 50 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 51 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) 52 | 53 | def before_integrate(self, t): 54 | f0 = self.func(t[0].type_as(self.y0[0]), self.y0) 55 | if self.first_step is None: 56 | first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t) 57 | else: 58 | first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device) 59 | self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5) 60 | 61 | def advance(self, next_t): 62 | """Interpolate through the next time point, integrating as necessary.""" 63 | n_steps = 0 64 | while next_t > self.rk_state.t1: 65 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) 66 | self.rk_state = self._adaptive_bosh3_step(self.rk_state) 67 | n_steps += 1 68 | return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) 69 | 70 | def _adaptive_bosh3_step(self, rk_state): 71 | """Take an adaptive Runge-Kutta step to integrate the ODE.""" 72 | y0, f0, _, t0, dt, interp_coeff = rk_state 73 | ######################################################## 74 | # Assertions # 75 | ######################################################## 76 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) 77 | for y0_ in y0: 78 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) 79 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_BOGACKI_SHAMPINE_TABLEAU) 80 | 81 | ######################################################## 82 | # Error Ratio # 83 | ######################################################## 84 | mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) 85 | accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all() 86 | 87 | ######################################################## 88 | # Update RK State # 89 | ######################################################## 90 | y_next = y1 if accept_step else y0 91 | f_next = f1 if accept_step else f0 92 | t_next = t0 + dt if accept_step else t0 93 | interp_coeff = _interp_fit_bosh3(y0, y1, k, dt) if accept_step else interp_coeff 94 | dt_next = _optimal_step_size( 95 | dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=3 96 | ) 97 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) 98 | return rk_state 99 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/dopri5.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate 2 | import torch 3 | from .misc import ( 4 | _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable, 5 | _optimal_step_size, _compute_error_ratio 6 | ) 7 | from .solvers import AdaptiveStepsizeODESolver 8 | from .interp import _interp_fit, _interp_evaluate 9 | from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step 10 | 11 | _DORMAND_PRINCE_SHAMPINE_TABLEAU = _ButcherTableau( 12 | alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], 13 | beta=[ 14 | [1 / 5], 15 | [3 / 40, 9 / 40], 16 | [44 / 45, -56 / 15, 32 / 9], 17 | [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], 18 | [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], 19 | [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], 20 | ], 21 | c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], 22 | c_error=[ 23 | 35 / 384 - 1951 / 21600, 24 | 0, 25 | 500 / 1113 - 22642 / 50085, 26 | 125 / 192 - 451 / 720, 27 | -2187 / 6784 - -12231 / 42400, 28 | 11 / 84 - 649 / 6300, 29 | -1. / 60., 30 | ], 31 | ) 32 | 33 | DPS_C_MID = [ 34 | 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, -2691868925 / 45128329728 / 2, 35 | 187940372067 / 1594534317056 / 2, -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 36 | ] 37 | 38 | 39 | def _interp_fit_dopri5(y0, y1, k, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU): 40 | """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" 41 | dt = dt.type_as(y0[0]) 42 | y_mid = tuple(y0_ + _scaled_dot_product(dt, DPS_C_MID, k_) for y0_, k_ in zip(y0, k)) 43 | f0 = tuple(k_[0] for k_ in k) 44 | f1 = tuple(k_[-1] for k_ in k) 45 | return _interp_fit(y0, y1, y_mid, f0, f1, dt) 46 | 47 | 48 | def _abs_square(x): 49 | return torch.mul(x, x) 50 | 51 | 52 | def _ta_append(list_of_tensors, value): 53 | """Append a value to the end of a list of PyTorch tensors.""" 54 | list_of_tensors.append(value) 55 | return list_of_tensors 56 | 57 | 58 | class Dopri5Solver(AdaptiveStepsizeODESolver): 59 | 60 | def __init__( 61 | self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, 62 | **unused_kwargs 63 | ): 64 | _handle_unused_kwargs(self, unused_kwargs) 65 | del unused_kwargs 66 | 67 | self.func = func 68 | self.y0 = y0 69 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 70 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0) 71 | self.first_step = first_step 72 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 73 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 74 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 75 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) 76 | 77 | def before_integrate(self, t): 78 | f0 = self.func(t[0].type_as(self.y0[0]), self.y0) 79 | if self.first_step is None: 80 | first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol[0], self.atol[0], f0=f0).to(t) 81 | else: 82 | first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device) 83 | self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5) 84 | 85 | def advance(self, next_t): 86 | """Interpolate through the next time point, integrating as necessary.""" 87 | n_steps = 0 88 | while next_t > self.rk_state.t1: 89 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) 90 | self.rk_state = self._adaptive_dopri5_step(self.rk_state) 91 | n_steps += 1 92 | return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) 93 | 94 | def _adaptive_dopri5_step(self, rk_state): 95 | """Take an adaptive Runge-Kutta step to integrate the ODE.""" 96 | y0, f0, _, t0, dt, interp_coeff = rk_state 97 | ######################################################## 98 | # Assertions # 99 | ######################################################## 100 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) 101 | for y0_ in y0: 102 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) 103 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_SHAMPINE_TABLEAU) 104 | 105 | ######################################################## 106 | # Error Ratio # 107 | ######################################################## 108 | mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) 109 | accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all() 110 | 111 | ######################################################## 112 | # Update RK State # 113 | ######################################################## 114 | y_next = y1 if accept_step else y0 115 | f_next = f1 if accept_step else f0 116 | t_next = t0 + dt if accept_step else t0 117 | interp_coeff = _interp_fit_dopri5(y0, y1, k, dt) if accept_step else interp_coeff 118 | dt_next = _optimal_step_size( 119 | dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=5 120 | ) 121 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) 122 | return rk_state 123 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/dopri8.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .misc import ( 3 | _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs, _is_iterable, 4 | _optimal_step_size, _compute_error_ratio 5 | ) 6 | from .solvers import AdaptiveStepsizeODESolver 7 | from .interp import _interp_fit, _interp_evaluate 8 | from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step 9 | from . import dopri8_coefficients 10 | 11 | 12 | 13 | _DOPRI8_TABLEAU = _ButcherTableau(alpha = dopri8_coefficients.A, 14 | beta = dopri8_coefficients.B, 15 | c_sol = dopri8_coefficients.C_sol, 16 | c_error = dopri8_coefficients.C_err, 17 | ) 18 | 19 | c_mid = dopri8_coefficients.C_mid 20 | 21 | def _interp_fit_dopri8(y0, y1, k, dt, tableau=_DOPRI8_TABLEAU): 22 | """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" 23 | dt = dt.type_as(y0[0]) 24 | y_mid = tuple(y0_ + _scaled_dot_product(dt, c_mid, k_) for y0_, k_ in zip(y0, k)) 25 | f0 = tuple(k_[0] for k_ in k) 26 | f1 = tuple(k_[-1] for k_ in k) 27 | return _interp_fit(y0, y1, y_mid, f0, f1, dt) 28 | 29 | def _abs_square(x): 30 | return torch.mul(x, x) 31 | 32 | 33 | def _ta_append(list_of_tensors, value): 34 | """Append a value to the end of a list of PyTorch tensors.""" 35 | list_of_tensors.append(value) 36 | return list_of_tensors 37 | 38 | 39 | class Dopri8Solver(AdaptiveStepsizeODESolver): 40 | 41 | def __init__( 42 | self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, 43 | **unused_kwargs 44 | ): 45 | _handle_unused_kwargs(self, unused_kwargs) 46 | del unused_kwargs 47 | 48 | self.func = func 49 | self.y0 = y0 50 | self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 51 | self.atol = atol if _is_iterable(atol) else [atol] * len(y0) 52 | self.first_step = first_step 53 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 54 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 55 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 56 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) 57 | 58 | def before_integrate(self, t): 59 | f0 = self.func(t[0].type_as(self.y0[0]), self.y0) 60 | if self.first_step is None: 61 | first_step = _select_initial_step(self.func, t[0], self.y0, 7, self.rtol[0], self.atol[0], f0=f0).to(t) 62 | else: 63 | first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device) 64 | self.rk_state = _RungeKuttaState(self.y0, f0, t[0], t[0], first_step, interp_coeff=[self.y0] * 5) 65 | 66 | def advance(self, next_t): 67 | """Interpolate through the next time point, integrating as necessary.""" 68 | n_steps = 0 69 | while next_t > self.rk_state.t1: 70 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) 71 | self.rk_state = self._adaptive_dopri8_step(self.rk_state) 72 | n_steps += 1 73 | return _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t) 74 | 75 | def _adaptive_dopri8_step(self, rk_state): 76 | """Take an adaptive Runge-Kutta step to integrate the ODE.""" 77 | y0, f0, _, t0, dt, interp_coeff = rk_state 78 | ######################################################## 79 | # Assertions # 80 | ######################################################## 81 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) 82 | for y0_ in y0: 83 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) 84 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_DOPRI8_TABLEAU) 85 | 86 | ######################################################## 87 | # Error Ratio # 88 | ######################################################## 89 | mean_sq_error_ratio = _compute_error_ratio(y1_error, atol=self.atol, rtol=self.rtol, y0=y0, y1=y1) 90 | accept_step = (torch.tensor(mean_sq_error_ratio) <= 1).all() 91 | 92 | ######################################################## 93 | # Update RK State # 94 | ######################################################## 95 | y_next = y1 if accept_step else y0 96 | f_next = f1 if accept_step else f0 97 | t_next = t0 + dt if accept_step else t0 98 | interp_coeff = _interp_fit_dopri8(y0, y1, k, dt) if accept_step else interp_coeff 99 | dt_next = _optimal_step_size( 100 | dt, mean_sq_error_ratio, safety=self.safety, ifactor=self.ifactor, dfactor=self.dfactor, order=8 101 | ) 102 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, interp_coeff) 103 | return rk_state 104 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/dopri8_coefficients.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | A = [ 1/18, 1/12, 1/8, 5/16, 3/8, 59/400, 93/200, 5490023248/9719169821, 13/20, 1201146811/1299019798, 1, 1, 1] 4 | 5 | B = [ 6 | [1/18], 7 | 8 | [1/48, 1/16], 9 | 10 | [1/32, 0, 3/32], 11 | 12 | [5/16, 0, -75/64, 75/64], 13 | 14 | [3/80, 0, 0, 3/16, 3/20], 15 | 16 | [29443841/614563906, 0, 0, 77736538/692538347, -28693883/1125000000, 23124283/1800000000], 17 | 18 | [16016141/946692911, 0, 0, 61564180/158732637, 22789713/633445777, 545815736/2771057229, -180193667/1043307555], 19 | 20 | [39632708/573591083, 0, 0, -433636366/683701615, -421739975/2616292301, 100302831/723423059, 790204164/839813087, 800635310/3783071287], 21 | 22 | [246121993/1340847787, 0, 0, -37695042795/15268766246, -309121744/1061227803, -12992083/490766935, 6005943493/2108947869, 393006217/1396673457, 123872331/1001029789], 23 | 24 | [-1028468189/846180014, 0, 0, 8478235783/508512852, 1311729495/1432422823, -10304129995/1701304382, -48777925059/3047939560, 15336726248/1032824649, -45442868181/3398467696, 3065993473/597172653], 25 | 26 | [185892177/718116043, 0, 0, -3185094517/667107341, -477755414/1098053517, -703635378/230739211, 5731566787/1027545527, 5232866602/850066563, -4093664535/808688257, 3962137247/1805957418, 65686358/487910083], 27 | 28 | [403863854/491063109, 0, 0, -5068492393/434740067, -411421997/543043805, 652783627/914296604, 11173962825/925320556, -13158990841/6184727034, 3936647629/1978049680, -160528059/685178525, 248638103/1413531060, 0], 29 | 30 | [ 14005451/335480064, 0, 0, 0, 0, -59238493/1068277825, 181606767/758867731, 561292985/797845732, -1041891430/1371343529, 760417239/1151165299, 118820643/751138087, -528747749/2220607170, 1/4] 31 | ] 32 | 33 | C_sol = [ 14005451/335480064, 0, 0, 0, 0, -59238493/1068277825, 181606767/758867731, 561292985/797845732, -1041891430/1371343529, 760417239/1151165299, 118820643/751138087, -528747749/2220607170, 1/4, 0] 34 | 35 | C_err = [ 14005451/335480064 - 13451932/455176623, 0, 0, 0, 0, -59238493/1068277825 - -808719846/976000145, 181606767/758867731 - 1757004468/5645159321, 561292985/797845732 - 656045339/265891186, -1041891430/1371343529 - -3867574721/1518517206, 760417239/1151165299 - 465885868/322736535, 118820643/751138087 - 53011238/667516719, -528747749/2220607170 - 2/45, 1/4, 0] 36 | 37 | h = 1/2 38 | 39 | C_mid = np.zeros(14) 40 | 41 | C_mid[0] = (- 6.3448349392860401388*(h**5) + 22.1396504998094068976*(h**4) - 30.0610568289666450593*(h**3) + 19.9990069333683970610*(h**2) - 6.6910181737837595697*h + 1.0) / (1/h) 42 | 43 | C_mid[5] = (- 39.6107919852202505218*(h**5) + 116.4422149550342161651*(h**4) - 121.4999627731334642623*(h**3) + 52.2273532792945524050*(h**2) - 7.6142658045872677172*h) / (1/h) 44 | 45 | C_mid[6] = (20.3761213808791436958*(h**5) - 67.1451318825957197185*(h**4) + 83.1721004639847717481*(h**3) - 46.8919164181093621583*(h**2) + 10.7281392630428866124*h) / (1/h) 46 | 47 | C_mid[7] = (7.3347098826795362023*(h**5) - 16.5672243527496524646*(h**4) + 9.5724507555993664382*(h**3) - 0.1890893225010595467*(h**2) + 0.5526637063753648783*h) / (1/h) 48 | 49 | C_mid[8] = (32.8801774352459155182*(h**5) - 89.9916014847245016028*(h**4) + 87.8406057677205645007*(h**3) - 35.7075975946222072821*(h**2) + 4.2186562625665153803*h) / (1/h) 50 | 51 | C_mid[9] = (- 10.1588990526426760954*(h**5) + 22.6237489648532849093*(h**4) - 17.4152107770762969005*(h**3) + 6.2736448083240352160*(h**2) - 0.6627209125361597559*h) / (1/h) 52 | 53 | C_mid[10] = (- 12.5401268098782561200*(h**5) + 32.2362340167355370113*(h**4) - 28.5903289514790976966*(h**3) + 10.3160881272450748458*(h**2) - 1.2636789001135462218*h) / (1/h) 54 | 55 | C_mid[11] = (29.5553001484516038033*(h**5) - 82.1020315488359848644*(h**4) + 81.6630950584341412934*(h**3) - 34.7650769866611817349*(h**2) + 5.4106037898590422230*h) / (1/h) 56 | 57 | C_mid[12] = (- 41.7923486424390588923*(h**5) + 116.2662185791119533462*(h**4) - 114.9375291377009418170*(h**3) + 47.7457971078225540396*(h**2) - 7.0321379067945741781*h) / (1/h) 58 | 59 | C_mid[13] = (20.3006925822100825485*(h**5) - 53.9020777466385396792*(h**4) + 50.2558364226176017553*(h**3) - 19.0082099341608028453*(h**2) + 2.3537586759714983486*h) / (1/h) 60 | 61 | C_mid = C_mid.tolist() 62 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/fixed_adams.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import collections 3 | from .solvers import FixedGridODESolver 4 | from .misc import _scaled_dot_product, _has_converged 5 | from . import rk_common 6 | 7 | _BASHFORTH_COEFFICIENTS = [ 8 | [], # order 0 9 | [11], 10 | [3, -1], 11 | [23, -16, 5], 12 | [55, -59, 37, -9], 13 | [1901, -2774, 2616, -1274, 251], 14 | [4277, -7923, 9982, -7298, 2877, -475], 15 | [198721, -447288, 705549, -688256, 407139, -134472, 19087], 16 | [434241, -1152169, 2183877, -2664477, 2102243, -1041723, 295767, -36799], 17 | [14097247, -43125206, 95476786, -139855262, 137968480, -91172642, 38833486, -9664106, 1070017], 18 | [30277247, -104995189, 265932680, -454661776, 538363838, -444772162, 252618224, -94307320, 20884811, -2082753], 19 | [ 20 | 2132509567, -8271795124, 23591063805, -46113029016, 63716378958, -63176201472, 44857168434, -22329634920, 21 | 7417904451, -1479574348, 134211265 22 | ], 23 | [ 24 | 4527766399, -19433810163, 61633227185, -135579356757, 214139355366, -247741639374, 211103573298, -131365867290, 25 | 58189107627, -17410248271, 3158642445, -262747265 26 | ], 27 | [ 28 | 13064406523627, -61497552797274, 214696591002612, -524924579905150, 932884546055895, -1233589244941764, 29 | 1226443086129408, -915883387152444, 507140369728425, -202322913738370, 55060974662412, -9160551085734, 30 | 703604254357 31 | ], 32 | [ 33 | 27511554976875, -140970750679621, 537247052515662, -1445313351681906, 2854429571790805, -4246767353305755, 34 | 4825671323488452, -4204551925534524, 2793869602879077, -1393306307155755, 505586141196430, -126174972681906, 35 | 19382853593787, -1382741929621 36 | ], 37 | [ 38 | 173233498598849, -960122866404112, 3966421670215481, -11643637530577472, 25298910337081429, -41825269932507728, 39 | 53471026659940509, -53246738660646912, 41280216336284259, -24704503655607728, 11205849753515179, 40 | -3728807256577472, 859236476684231, -122594813904112, 8164168737599 41 | ], 42 | [ 43 | 362555126427073, -2161567671248849, 9622096909515337, -30607373860520569, 72558117072259733, 44 | -131963191940828581, 187463140112902893, -210020588912321949, 186087544263596643, -129930094104237331, 45 | 70724351582843483, -29417910911251819, 9038571752734087, -1934443196892599, 257650275915823, -16088129229375 46 | ], 47 | [ 48 | 192996103681340479, -1231887339593444974, 5878428128276811750, -20141834622844109630, 51733880057282977010, 49 | -102651404730855807942, 160414858999474733422, -199694296833704562550, 199061418623907202560, 50 | -158848144481581407370, 100878076849144434322, -50353311405771659322, 19338911944324897550, 51 | -5518639984393844930, 1102560345141059610, -137692773163513234, 8092989203533249 52 | ], 53 | [ 54 | 401972381695456831, -2735437642844079789, 13930159965811142228, -51150187791975812900, 141500575026572531760, 55 | -304188128232928718008, 518600355541383671092, -710171024091234303204, 786600875277595877750, 56 | -706174326992944287370, 512538584122114046748, -298477260353977522892, 137563142659866897224, 57 | -49070094880794267600, 13071639236569712860, -2448689255584545196, 287848942064256339, -15980174332775873 58 | ], 59 | [ 60 | 333374427829017307697, -2409687649238345289684, 13044139139831833251471, -51099831122607588046344, 61 | 151474888613495715415020, -350702929608291455167896, 647758157491921902292692, -967713746544629658690408, 62 | 1179078743786280451953222, -1176161829956768365219840, 960377035444205950813626, -639182123082298748001432, 63 | 343690461612471516746028, -147118738993288163742312, 48988597853073465932820, -12236035290567356418552, 64 | 2157574942881818312049, -239560589366324764716, 12600467236042756559 65 | ], 66 | [ 67 | 691668239157222107697, -5292843584961252933125, 30349492858024727686755, -126346544855927856134295, 68 | 399537307669842150996468, -991168450545135070835076, 1971629028083798845750380, -3191065388846318679544380, 69 | 4241614331208149947151790, -4654326468801478894406214, 4222756879776354065593786, -3161821089800186539248210, 70 | 1943018818982002395655620, -970350191086531368649620, 387739787034699092364924, -121059601023985433003532, 71 | 28462032496476316665705, -4740335757093710713245, 498669220956647866875, -24919383499187492303 72 | ], 73 | ] 74 | 75 | _MOULTON_COEFFICIENTS = [ 76 | [], # order 0 77 | [1], 78 | [1, 1], 79 | [5, 8, -1], 80 | [9, 19, -5, 1], 81 | [251, 646, -264, 106, -19], 82 | [475, 1427, -798, 482, -173, 27], 83 | [19087, 65112, -46461, 37504, -20211, 6312, -863], 84 | [36799, 139849, -121797, 123133, -88547, 41499, -11351, 1375], 85 | [1070017, 4467094, -4604594, 5595358, -5033120, 3146338, -1291214, 312874, -33953], 86 | [2082753, 9449717, -11271304, 16002320, -17283646, 13510082, -7394032, 2687864, -583435, 57281], 87 | [ 88 | 134211265, 656185652, -890175549, 1446205080, -1823311566, 1710774528, -1170597042, 567450984, -184776195, 89 | 36284876, -3250433 90 | ], 91 | [ 92 | 262747265, 1374799219, -2092490673, 3828828885, -5519460582, 6043521486, -4963166514, 3007739418, -1305971115, 93 | 384709327, -68928781, 5675265 94 | ], 95 | [ 96 | 703604254357, 3917551216986, -6616420957428, 13465774256510, -21847538039895, 27345870698436, -26204344465152, 97 | 19058185652796, -10344711794985, 4063327863170, -1092096992268, 179842822566, -13695779093 98 | ], 99 | [ 100 | 1382741929621, 8153167962181, -15141235084110, 33928990133618, -61188680131285, 86180228689563, -94393338653892, 101 | 80101021029180, -52177910882661, 25620259777835, -9181635605134, 2268078814386, -345457086395, 24466579093 102 | ], 103 | [ 104 | 8164168737599, 50770967534864, -102885148956217, 251724894607936, -499547203754837, 781911618071632, 105 | -963605400824733, 934600833490944, -710312834197347, 418551804601264, -187504936597931, 61759426692544, 106 | -14110480969927, 1998759236336, -132282840127 107 | ], 108 | [ 109 | 16088129229375, 105145058757073, -230992163723849, 612744541065337, -1326978663058069, 2285168598349733, 110 | -3129453071993581, 3414941728852893, -2966365730265699, 2039345879546643, -1096355235402331, 451403108933483, 111 | -137515713789319, 29219384284087, -3867689367599, 240208245823 112 | ], 113 | [ 114 | 8092989203533249, 55415287221275246, -131240807912923110, 375195469874202430, -880520318434977010, 115 | 1654462865819232198, -2492570347928318318, 3022404969160106870, -2953729295811279360, 2320851086013919370, 116 | -1455690451266780818, 719242466216944698, -273894214307914510, 77597639915764930, -15407325991235610, 117 | 1913813460537746, -111956703448001 118 | ], 119 | [ 120 | 15980174332775873, 114329243705491117, -290470969929371220, 890337710266029860, -2250854333681641520, 121 | 4582441343348851896, -7532171919277411636, 10047287575124288740, -10910555637627652470, 9644799218032932490, 122 | -6913858539337636636, 3985516155854664396, -1821304040326216520, 645008976643217360, -170761422500096220, 123 | 31816981024600492, -3722582669836627, 205804074290625 124 | ], 125 | [ 126 | 12600467236042756559, 93965550344204933076, -255007751875033918095, 834286388106402145800, 127 | -2260420115705863623660, 4956655592790542146968, -8827052559979384209108, 12845814402199484797800, 128 | -15345231910046032448070, 15072781455122686545920, -12155867625610599812538, 8008520809622324571288, 129 | -4269779992576330506540, 1814584564159445787240, -600505972582990474260, 149186846171741510136, 130 | -26182538841925312881, 2895045518506940460, -151711881512390095 131 | ], 132 | [ 133 | 24919383499187492303, 193280569173472261637, -558160720115629395555, 1941395668950986461335, 134 | -5612131802364455926260, 13187185898439270330756, -25293146116627869170796, 39878419226784442421820, 135 | -51970649453670274135470, 56154678684618739939910, -50320851025594566473146, 37297227252822858381906, 136 | -22726350407538133839300, 11268210124987992327060, -4474886658024166985340, 1389665263296211699212, 137 | -325187970422032795497, 53935307402575440285, -5652892248087175675, 281550972898020815 138 | ], 139 | ] 140 | 141 | _DIVISOR = [ 142 | None, 11, 2, 12, 24, 720, 1440, 60480, 120960, 3628800, 7257600, 479001600, 958003200, 2615348736000, 5230697472000, 143 | 31384184832000, 62768369664000, 32011868528640000, 64023737057280000, 51090942171709440000, 102181884343418880000 144 | ] 145 | 146 | _MIN_ORDER = 4 147 | _MAX_ORDER = 12 148 | _MAX_ITERS = 4 149 | 150 | 151 | class AdamsBashforthMoulton(FixedGridODESolver): 152 | 153 | def __init__( 154 | self, func, y0, rtol=1e-3, atol=1e-4, implicit=True, max_iters=_MAX_ITERS, max_order=_MAX_ORDER, **kwargs 155 | ): 156 | super(AdamsBashforthMoulton, self).__init__(func, y0, **kwargs) 157 | 158 | self.rtol = rtol 159 | self.atol = atol 160 | self.implicit = implicit 161 | self.max_iters = max_iters 162 | self.max_order = int(min(max_order, _MAX_ORDER)) 163 | self.prev_f = collections.deque(maxlen=self.max_order - 1) 164 | self.prev_t = None 165 | 166 | def _update_history(self, t, f): 167 | if self.prev_t is None or self.prev_t != t: 168 | self.prev_f.appendleft(f) 169 | self.prev_t = t 170 | 171 | def step_func(self, func, t, dt, y): 172 | self._update_history(t, func(t, y)) 173 | order = min(len(self.prev_f), self.max_order - 1) 174 | if order < _MIN_ORDER - 1: 175 | # Compute using RK4. 176 | dy = rk_common.rk4_alt_step_func(func, t, dt, y, k1=self.prev_f[0]) 177 | return dy 178 | else: 179 | # Adams-Bashforth predictor. 180 | bashforth_coeffs = _BASHFORTH_COEFFICIENTS[order] 181 | ab_div = _DIVISOR[order] 182 | dy = tuple(dt * _scaled_dot_product(1 / ab_div, bashforth_coeffs, f_) for f_ in zip(*self.prev_f)) 183 | 184 | # Adams-Moulton corrector. 185 | if self.implicit: 186 | moulton_coeffs = _MOULTON_COEFFICIENTS[order + 1] 187 | am_div = _DIVISOR[order + 1] 188 | delta = tuple(dt * _scaled_dot_product(1 / am_div, moulton_coeffs[1:], f_) for f_ in zip(*self.prev_f)) 189 | converged = False 190 | for _ in range(self.max_iters): 191 | dy_old = dy 192 | f = func(t + dt, tuple(y_ + dy_ for y_, dy_ in zip(y, dy))) 193 | dy = tuple(dt * (moulton_coeffs[0] / am_div) * f_ + delta_ for f_, delta_ in zip(f, delta)) 194 | converged = _has_converged(dy_old, dy, self.rtol, self.atol) 195 | if converged: 196 | break 197 | if not converged: 198 | print('Warning: Functional iteration did not converge. Solution may be incorrect.', file=sys.stderr) 199 | self.prev_f.pop() 200 | self._update_history(t, f) 201 | return dy 202 | 203 | @property 204 | def order(self): 205 | return 4 206 | 207 | 208 | class AdamsBashforth(AdamsBashforthMoulton): 209 | 210 | def __init__(self, func, y0, **kwargs): 211 | super(AdamsBashforth, self).__init__(func, y0, implicit=False, **kwargs) 212 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/fixed_grid.py: -------------------------------------------------------------------------------- 1 | from .solvers import FixedGridODESolver 2 | from . import rk_common 3 | 4 | 5 | class Euler(FixedGridODESolver): 6 | 7 | def step_func(self, func, t, dt, y): 8 | return tuple(dt * f_ for f_ in func(t, y)) 9 | 10 | @property 11 | def order(self): 12 | return 1 13 | 14 | 15 | class Midpoint(FixedGridODESolver): 16 | 17 | def step_func(self, func, t, dt, y): 18 | y_mid = tuple(y_ + f_ * dt / 2 for y_, f_ in zip(y, func(t, y))) 19 | return tuple(dt * f_ for f_ in func(t + dt / 2, y_mid)) 20 | 21 | @property 22 | def order(self): 23 | return 2 24 | 25 | 26 | class RK4(FixedGridODESolver): 27 | 28 | def step_func(self, func, t, dt, y): 29 | return rk_common.rk4_alt_step_func(func, t, dt, y) 30 | 31 | @property 32 | def order(self): 33 | return 4 34 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/interp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .misc import _convert_to_tensor, _dot_product 3 | 4 | 5 | def _interp_fit(y0, y1, y_mid, f0, f1, dt): 6 | """Fit coefficients for 4th order polynomial interpolation. 7 | 8 | Args: 9 | y0: function value at the start of the interval. 10 | y1: function value at the end of the interval. 11 | y_mid: function value at the mid-point of the interval. 12 | f0: derivative value at the start of the interval. 13 | f1: derivative value at the end of the interval. 14 | dt: width of the interval. 15 | 16 | Returns: 17 | List of coefficients `[a, b, c, d, e]` for interpolating with the polynomial 18 | `p = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e` for values of `x` 19 | between 0 (start of interval) and 1 (end of interval). 20 | """ 21 | a = tuple( 22 | _dot_product([-2 * dt, 2 * dt, -8, -8, 16], [f0_, f1_, y0_, y1_, y_mid_]) 23 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) 24 | ) 25 | b = tuple( 26 | _dot_product([5 * dt, -3 * dt, 18, 14, -32], [f0_, f1_, y0_, y1_, y_mid_]) 27 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) 28 | ) 29 | c = tuple( 30 | _dot_product([-4 * dt, dt, -11, -5, 16], [f0_, f1_, y0_, y1_, y_mid_]) 31 | for f0_, f1_, y0_, y1_, y_mid_ in zip(f0, f1, y0, y1, y_mid) 32 | ) 33 | d = tuple(dt * f0_ for f0_ in f0) 34 | e = y0 35 | return [a, b, c, d, e] 36 | 37 | 38 | def _interp_evaluate(coefficients, t0, t1, t): 39 | """Evaluate polynomial interpolation at the given time point. 40 | 41 | Args: 42 | coefficients: list of Tensor coefficients as created by `interp_fit`. 43 | t0: scalar float64 Tensor giving the start of the interval. 44 | t1: scalar float64 Tensor giving the end of the interval. 45 | t: scalar float64 Tensor giving the desired interpolation point. 46 | 47 | Returns: 48 | Polynomial interpolation of the coefficients at time `t`. 49 | """ 50 | 51 | dtype = coefficients[0][0].dtype 52 | device = coefficients[0][0].device 53 | 54 | t0 = _convert_to_tensor(t0, dtype=dtype, device=device) 55 | t1 = _convert_to_tensor(t1, dtype=dtype, device=device) 56 | t = _convert_to_tensor(t, dtype=dtype, device=device) 57 | 58 | assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1) 59 | x = ((t - t0) / (t1 - t0)).type(dtype).to(device) 60 | 61 | xs = [torch.tensor(1).type(dtype).to(device), x] 62 | for _ in range(2, len(coefficients)): 63 | xs.append(xs[-1] * x) 64 | 65 | return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients)) 66 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/misc.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torch 3 | 4 | 5 | def _flatten(sequence): 6 | flat = [p.contiguous().view(-1) for p in sequence] 7 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) 8 | 9 | 10 | def _flatten_convert_none_to_zeros(sequence, like_sequence): 11 | flat = [ 12 | p.contiguous().view(-1) if p is not None else torch.zeros_like(q).view(-1) 13 | for p, q in zip(sequence, like_sequence) 14 | ] 15 | return torch.cat(flat) if len(flat) > 0 else torch.tensor([]) 16 | 17 | 18 | def _possibly_nonzero(x): 19 | return isinstance(x, torch.Tensor) or x != 0 20 | 21 | 22 | def _scaled_dot_product(scale, xs, ys): 23 | """Calculate a scaled, vector inner product between lists of Tensors.""" 24 | # Using _possibly_nonzero lets us avoid wasted computation. 25 | return sum([(scale * x) * y for x, y in zip(xs, ys) if _possibly_nonzero(x) or _possibly_nonzero(y)]) 26 | 27 | 28 | def _dot_product(xs, ys): 29 | """Calculate the vector inner product between two lists of Tensors.""" 30 | return sum([x * y for x, y in zip(xs, ys)]) 31 | 32 | 33 | def _has_converged(y0, y1, rtol, atol): 34 | """Checks that each element is within the error tolerance.""" 35 | error_tol = tuple(atol + rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1)) 36 | error = tuple(torch.abs(y0_ - y1_) for y0_, y1_ in zip(y0, y1)) 37 | return all((error_ < error_tol_).all() for error_, error_tol_ in zip(error, error_tol)) 38 | 39 | 40 | def _convert_to_tensor(a, dtype=None, device=None): 41 | if not isinstance(a, torch.Tensor): 42 | a = torch.tensor(a) 43 | if dtype is not None: 44 | a = a.type(dtype) 45 | if device is not None: 46 | a = a.to(device) 47 | return a 48 | 49 | 50 | def _is_finite(tensor): 51 | _check = (tensor == float('inf')) + (tensor == float('-inf')) + torch.isnan(tensor) 52 | return not _check.any() 53 | 54 | 55 | def _decreasing(t): 56 | return (t[1:] < t[:-1]).all() 57 | 58 | 59 | def _assert_increasing(t): 60 | assert (t[1:] > t[:-1]).all(), 't must be strictly increasing or decreasing' 61 | 62 | 63 | def _is_iterable(inputs): 64 | try: 65 | iter(inputs) 66 | return True 67 | except TypeError: 68 | return False 69 | 70 | 71 | def _norm(x): 72 | """Compute RMS norm.""" 73 | if torch.is_tensor(x): 74 | return x.norm() / (x.numel()**0.5) 75 | else: 76 | return torch.sqrt(sum(x_.norm()**2 for x_ in x) / sum(x_.numel() for x_ in x)) 77 | 78 | 79 | def _handle_unused_kwargs(solver, unused_kwargs): 80 | if len(unused_kwargs) > 0: 81 | warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs)) 82 | 83 | 84 | def _select_initial_step(fun, t0, y0, order, rtol, atol, f0=None): 85 | """Empirically select a good initial step. 86 | 87 | The algorithm is described in [1]_. 88 | 89 | Parameters 90 | ---------- 91 | fun : callable 92 | Right-hand side of the system. 93 | t0 : float 94 | Initial value of the independent variable. 95 | y0 : ndarray, shape (n,) 96 | Initial value of the dependent variable. 97 | direction : float 98 | Integration direction. 99 | order : float 100 | Method order. 101 | rtol : float 102 | Desired relative tolerance. 103 | atol : float 104 | Desired absolute tolerance. 105 | 106 | Returns 107 | ------- 108 | h_abs : float 109 | Absolute value of the suggested initial step. 110 | 111 | References 112 | ---------- 113 | .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential 114 | Equations I: Nonstiff Problems", Sec. II.4. 115 | """ 116 | t0 = t0.to(y0[0]) 117 | if f0 is None: 118 | f0 = fun(t0, y0) 119 | 120 | rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0) 121 | atol = atol if _is_iterable(atol) else [atol] * len(y0) 122 | 123 | scale = tuple(atol_ + torch.abs(y0_) * rtol_ for y0_, atol_, rtol_ in zip(y0, atol, rtol)) 124 | 125 | d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale)) 126 | d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale)) 127 | 128 | if max(d0).item() < 1e-5 or max(d1).item() < 1e-5: 129 | h0 = torch.tensor(1e-6).to(t0) 130 | else: 131 | h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1)) 132 | 133 | y1 = tuple(y0_ + h0 * f0_ for y0_, f0_ in zip(y0, f0)) 134 | f1 = fun(t0 + h0, y1) 135 | 136 | d2 = tuple(_norm((f1_ - f0_) / scale_) / h0 for f1_, f0_, scale_ in zip(f1, f0, scale)) 137 | 138 | if max(d1).item() <= 1e-15 and max(d2).item() <= 1e-15: 139 | h1 = torch.max(torch.tensor(1e-6).to(h0), h0 * 1e-3) 140 | else: 141 | h1 = (0.01 / max(d1 + d2))**(1. / float(order + 1)) 142 | 143 | return torch.min(100 * h0, h1) 144 | 145 | 146 | def _compute_error_ratio(error_estimate, error_tol=None, rtol=None, atol=None, y0=None, y1=None): 147 | if error_tol is None: 148 | assert rtol is not None and atol is not None and y0 is not None and y1 is not None 149 | rtol if _is_iterable(rtol) else [rtol] * len(y0) 150 | atol if _is_iterable(atol) else [atol] * len(y0) 151 | error_tol = tuple( 152 | atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_)) 153 | for atol_, rtol_, y0_, y1_ in zip(atol, rtol, y0, y1) 154 | ) 155 | error_ratio = tuple(error_estimate_ / error_tol_ for error_estimate_, error_tol_ in zip(error_estimate, error_tol)) 156 | mean_sq_error_ratio = tuple(torch.mean(error_ratio_ * error_ratio_) for error_ratio_ in error_ratio) 157 | return mean_sq_error_ratio 158 | 159 | 160 | def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5): 161 | """Calculate the optimal size for the next step.""" 162 | mean_error_ratio = max(mean_error_ratio) # Compute step size based on highest ratio. 163 | if mean_error_ratio == 0: 164 | return last_step * ifactor 165 | if mean_error_ratio < 1: 166 | dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device) 167 | error_ratio = torch.sqrt(mean_error_ratio).to(last_step) 168 | exponent = torch.tensor(1 / order).to(last_step) 169 | factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor)) 170 | return last_step / factor 171 | 172 | 173 | def _check_inputs(func, y0, t): 174 | tensor_input = False 175 | if torch.is_tensor(y0): 176 | tensor_input = True 177 | y0 = (y0,) 178 | _base_nontuple_func_ = func 179 | func = lambda t, y: (_base_nontuple_func_(t, y[0]),) 180 | assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple' 181 | for y0_ in y0: 182 | assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(type(y0_)) 183 | 184 | if _decreasing(t): 185 | t = -t 186 | _base_reverse_func = func 187 | func = lambda t, y: tuple(-f_ for f_ in _base_reverse_func(-t, y)) 188 | 189 | for y0_ in y0: 190 | if not torch.is_floating_point(y0_): 191 | raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type())) 192 | if not torch.is_floating_point(t): 193 | raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type())) 194 | 195 | return tensor_input, func, y0, t 196 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/odeint.py: -------------------------------------------------------------------------------- 1 | from .tsit5 import Tsit5Solver 2 | from .dopri5 import Dopri5Solver 3 | from .bosh3 import Bosh3Solver 4 | from .adaptive_heun import AdaptiveHeunSolver 5 | from .fixed_grid import Euler, Midpoint, RK4 6 | from .fixed_adams import AdamsBashforth, AdamsBashforthMoulton 7 | from .adams import VariableCoefficientAdamsBashforth 8 | from .dopri8 import Dopri8Solver 9 | from .misc import _check_inputs 10 | 11 | SOLVERS = { 12 | 'explicit_adams': AdamsBashforth, 13 | 'fixed_adams': AdamsBashforthMoulton, 14 | 'adams': VariableCoefficientAdamsBashforth, 15 | 'tsit5': Tsit5Solver, 16 | 'dopri5': Dopri5Solver, 17 | 'bosh3': Bosh3Solver, 18 | 'euler': Euler, 19 | 'midpoint': Midpoint, 20 | 'rk4': RK4, 21 | 'adaptive_heun': AdaptiveHeunSolver, 22 | 'dopri8': Dopri8Solver, 23 | } 24 | 25 | 26 | def odeint(func, y0, t, rtol=1e-7, atol=1e-9, method=None, options=None): 27 | """Integrate a system of ordinary differential equations. 28 | 29 | Solves the initial value problem for a non-stiff system of first order ODEs: 30 | ``` 31 | dy/dt = func(t, y), y(t[0]) = y0 32 | ``` 33 | where y is a Tensor or tuple of Tensors of any shape. 34 | 35 | Output dtypes and numerical precision are based on the dtypes of the inputs `y0`. 36 | 37 | Args: 38 | func: Function that maps a scalar Tensor `t` and a Tensor holding the state `y` 39 | into a Tensor of state derivatives with respect to time. Optionally, `y` 40 | can also be a tuple of Tensors. 41 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. Optionally, `y0` 42 | can also be a tuple of Tensors. 43 | t: 1-D Tensor holding a sequence of time points for which to solve for 44 | `y`. The initial time point should be the first element of this sequence, 45 | and each time must be larger than the previous time. 46 | rtol: optional float64 Tensor specifying an upper bound on relative error, 47 | per element of `y`. 48 | atol: optional float64 Tensor specifying an upper bound on absolute error, 49 | per element of `y`. 50 | method: optional string indicating the integration method to use. 51 | options: optional dict of configuring options for the indicated integration 52 | method. Can only be provided if a `method` is explicitly set. 53 | 54 | Returns: 55 | y: Tensor, where the first dimension corresponds to different 56 | time points. Contains the solved value of y for each desired time point in 57 | `t`, with the initial value `y0` being the first element along the first 58 | dimension. 59 | 60 | Raises: 61 | ValueError: if an invalid `method` is provided. 62 | """ 63 | 64 | tensor_input, func, y0, t = _check_inputs(func, y0, t) 65 | 66 | if options is None: 67 | options = {} 68 | elif method is None: 69 | raise ValueError('cannot supply `options` without specifying `method`') 70 | 71 | if method is None: 72 | method = 'dopri5' 73 | 74 | if method not in SOLVERS: 75 | raise ValueError('Invalid method "{}". Must be one of {}'.format( 76 | method, '{"' + '", "'.join(SOLVERS.keys()) + '"}.')) 77 | 78 | solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, **options) 79 | solution = solver.integrate(t) 80 | 81 | if tensor_input: 82 | solution = solution[0] 83 | return solution 84 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/rk_common.py: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/integrate 2 | import collections 3 | from .misc import _scaled_dot_product, _convert_to_tensor 4 | 5 | _ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha beta c_sol c_error') 6 | 7 | 8 | class _RungeKuttaState(collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')): 9 | """Saved state of the Runge Kutta solver. 10 | 11 | Attributes: 12 | y1: Tensor giving the function value at the end of the last time step. 13 | f1: Tensor giving derivative at the end of the last time step. 14 | t0: scalar float64 Tensor giving start of the last time step. 15 | t1: scalar float64 Tensor giving end of the last time step. 16 | dt: scalar float64 Tensor giving the size for the next time step. 17 | interp_coef: list of Tensors giving coefficients for polynomial 18 | interpolation between `t0` and `t1`. 19 | """ 20 | 21 | 22 | def _runge_kutta_step(func, y0, f0, t0, dt, tableau): 23 | """Take an arbitrary Runge-Kutta step and estimate error. 24 | 25 | Args: 26 | func: Function to evaluate like `func(t, y)` to compute the time derivative 27 | of `y`. 28 | y0: Tensor initial value for the state. 29 | f0: Tensor initial value for the derivative, computed from `func(t0, y0)`. 30 | t0: float64 scalar Tensor giving the initial time. 31 | dt: float64 scalar Tensor giving the size of the desired time step. 32 | tableau: optional _ButcherTableau describing how to take the Runge-Kutta 33 | step. 34 | name: optional name for the operation. 35 | 36 | Returns: 37 | Tuple `(y1, f1, y1_error, k)` giving the estimated function value after 38 | the Runge-Kutta step at `t1 = t0 + dt`, the derivative of the state at `t1`, 39 | estimated error at `t1`, and a list of Runge-Kutta coefficients `k` used for 40 | calculating these terms. 41 | """ 42 | dtype = y0[0].dtype 43 | device = y0[0].device 44 | 45 | t0 = _convert_to_tensor(t0, dtype=dtype, device=device) 46 | dt = _convert_to_tensor(dt, dtype=dtype, device=device) 47 | 48 | k = tuple(map(lambda x: [x], f0)) 49 | for alpha_i, beta_i in zip(tableau.alpha, tableau.beta): 50 | ti = t0 + alpha_i * dt 51 | yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k)) 52 | tuple(k_.append(f_) for k_, f_ in zip(k, func(ti, yi))) 53 | 54 | if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]): 55 | # This property (true for Dormand-Prince) lets us save a few FLOPs. 56 | yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k)) 57 | 58 | y1 = yi 59 | f1 = tuple(k_[-1] for k_ in k) 60 | y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k) 61 | return (y1, f1, y1_error, k) 62 | 63 | 64 | def rk4_step_func(func, t, dt, y, k1=None): 65 | if k1 is None: k1 = func(t, y) 66 | k2 = func(t + dt / 2, tuple(y_ + dt * k1_ / 2 for y_, k1_ in zip(y, k1))) 67 | k3 = func(t + dt / 2, tuple(y_ + dt * k2_ / 2 for y_, k2_ in zip(y, k2))) 68 | k4 = func(t + dt, tuple(y_ + dt * k3_ for y_, k3_ in zip(y, k3))) 69 | return tuple((k1_ + 2 * k2_ + 2 * k3_ + k4_) * (dt / 6) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4)) 70 | 71 | 72 | def rk4_alt_step_func(func, t, dt, y, k1=None): 73 | """Smaller error with slightly more compute.""" 74 | if k1 is None: k1 = func(t, y) 75 | k2 = func(t + dt / 3, tuple(y_ + dt * k1_ / 3 for y_, k1_ in zip(y, k1))) 76 | k3 = func(t + dt * 2 / 3, tuple(y_ + dt * (k1_ / -3 + k2_) for y_, k1_, k2_ in zip(y, k1, k2))) 77 | k4 = func(t + dt, tuple(y_ + dt * (k1_ - k2_ + k3_) for y_, k1_, k2_, k3_ in zip(y, k1, k2, k3))) 78 | return tuple((k1_ + 3 * k2_ + 3 * k3_ + k4_) * (dt / 8) for k1_, k2_, k3_, k4_ in zip(k1, k2, k3, k4)) 79 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/solvers.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | from .misc import _assert_increasing, _handle_unused_kwargs 4 | 5 | 6 | class AdaptiveStepsizeODESolver(object): 7 | __metaclass__ = abc.ABCMeta 8 | 9 | def __init__(self, func, y0, atol, rtol, **unused_kwargs): 10 | _handle_unused_kwargs(self, unused_kwargs) 11 | del unused_kwargs 12 | 13 | self.func = func 14 | self.y0 = y0 15 | self.atol = atol 16 | self.rtol = rtol 17 | 18 | def before_integrate(self, t): 19 | pass 20 | 21 | @abc.abstractmethod 22 | def advance(self, next_t): 23 | raise NotImplementedError 24 | 25 | def integrate(self, t): 26 | _assert_increasing(t) 27 | solution = [self.y0] 28 | t = t.to(self.y0[0].device, torch.float64) 29 | self.before_integrate(t) 30 | for i in range(1, len(t)): 31 | y = self.advance(t[i]) 32 | solution.append(y) 33 | return tuple(map(torch.stack, tuple(zip(*solution)))) 34 | 35 | 36 | class FixedGridODESolver(object): 37 | __metaclass__ = abc.ABCMeta 38 | 39 | def __init__(self, func, y0, step_size=None, grid_constructor=None, **unused_kwargs): 40 | unused_kwargs.pop('rtol', None) 41 | unused_kwargs.pop('atol', None) 42 | _handle_unused_kwargs(self, unused_kwargs) 43 | del unused_kwargs 44 | 45 | self.func = func 46 | self.y0 = y0 47 | 48 | if step_size is None: 49 | if grid_constructor is None: 50 | self.grid_constructor = lambda f, y0, t: t 51 | else: 52 | self.grid_constructor = grid_constructor 53 | else: 54 | if grid_constructor is None: 55 | self.grid_constructor = self._grid_constructor_from_step_size(step_size) 56 | else: 57 | raise ValueError("step_size and grid_constructor are exclusive arguments.") 58 | 59 | def _grid_constructor_from_step_size(self, step_size): 60 | 61 | def _grid_constructor(func, y0, t): 62 | start_time = t[0] 63 | end_time = t[-1] 64 | 65 | niters = torch.ceil((end_time - start_time) / step_size + 1).item() 66 | t_infer = torch.arange(0, niters).to(t) * step_size + start_time 67 | if t_infer[-1] > t[-1]: 68 | t_infer[-1] = t[-1] 69 | 70 | return t_infer 71 | 72 | return _grid_constructor 73 | 74 | @property 75 | @abc.abstractmethod 76 | def order(self): 77 | pass 78 | 79 | @abc.abstractmethod 80 | def step_func(self, func, t, dt, y): 81 | pass 82 | 83 | def integrate(self, t): 84 | _assert_increasing(t) 85 | t = t.type_as(self.y0[0]) 86 | time_grid = self.grid_constructor(self.func, self.y0, t) 87 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1] 88 | time_grid = time_grid.to(self.y0[0]) 89 | 90 | solution = [self.y0] 91 | 92 | j = 1 93 | y0 = self.y0 94 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]): 95 | dy = self.step_func(self.func, t0, t1 - t0, y0) 96 | y1 = tuple(y0_ + dy_ for y0_, dy_ in zip(y0, dy)) 97 | 98 | while j < len(t) and t1 >= t[j]: 99 | solution.append(self._linear_interp(t0, t1, y0, y1, t[j])) 100 | j += 1 101 | y0 = y1 102 | 103 | return tuple(map(torch.stack, tuple(zip(*solution)))) 104 | 105 | def _linear_interp(self, t0, t1, y0, y1, t): 106 | if t == t0: 107 | return y0 108 | if t == t1: 109 | return y1 110 | t0, t1, t = t0.to(y0[0]), t1.to(y0[0]), t.to(y0[0]) 111 | slope = tuple((y1_ - y0_) / (t1 - t0) for y0_, y1_, in zip(y0, y1)) 112 | return tuple(y0_ + slope_ * (t - t0) for y0_, slope_ in zip(y0, slope)) 113 | -------------------------------------------------------------------------------- /ode_nn/torchdiffeq/_impl/tsit5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .misc import _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs 3 | from .solvers import AdaptiveStepsizeODESolver 4 | from .rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step 5 | 6 | # Parameters from Tsitouras (2011). 7 | _TSITOURAS_TABLEAU = _ButcherTableau( 8 | alpha=[0.161, 0.327, 0.9, 0.9800255409045097, 1., 1.], 9 | beta=[ 10 | [0.161], 11 | [-0.008480655492357, 0.3354806554923570], 12 | [2.897153057105494, -6.359448489975075, 4.362295432869581], 13 | [5.32586482843925895, -11.74888356406283, 7.495539342889836, -0.09249506636175525], 14 | [5.86145544294642038, -12.92096931784711, 8.159367898576159, -0.071584973281401006, -0.02826905039406838], 15 | [0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774], 16 | ], 17 | c_sol=[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774, 0], 18 | c_error=[ 19 | 0.09646076681806523 - 0.001780011052226, 20 | 0.01 - 0.000816434459657, 21 | 0.4798896504144996 - -0.007880878010262, 22 | 1.379008574103742 - 0.144711007173263, 23 | -3.290069515436081 - -0.582357165452555, 24 | 2.324710524099774 - 0.458082105929187, 25 | -1 / 66, 26 | ], 27 | ) 28 | 29 | 30 | def _interp_coeff_tsit5(t0, dt, eval_t): 31 | t = float((eval_t - t0) / dt) 32 | b1 = -1.0530884977290216 * t * (t - 1.3299890189751412) * (t**2 - 1.4364028541716351 * t + 0.7139816917074209) 33 | b2 = 0.1017 * t**2 * (t**2 - 2.1966568338249754 * t + 1.2949852507374631) 34 | b3 = 2.490627285651252793 * t**2 * (t**2 - 2.38535645472061657 * t + 1.57803468208092486) 35 | b4 = -16.54810288924490272 * (t - 1.21712927295533244) * (t - 0.61620406037800089) * t**2 36 | b5 = 47.37952196281928122 * (t - 1.203071208372362603) * (t - 0.658047292653547382) * t**2 37 | b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2 38 | b7 = 2.5 * (t - 1) * (t - 0.6) * t**2 39 | return [b1, b2, b3, b4, b5, b6, b7] 40 | 41 | 42 | def _interp_eval_tsit5(t0, t1, k, eval_t): 43 | dt = t1 - t0 44 | y0 = tuple(k_[0] for k_ in k) 45 | interp_coeff = _interp_coeff_tsit5(t0, dt, eval_t) 46 | y_t = tuple(y0_ + _scaled_dot_product(dt, interp_coeff, k_) for y0_, k_ in zip(y0, k)) 47 | return y_t 48 | 49 | 50 | def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5): 51 | """Calculate the optimal size for the next Runge-Kutta step.""" 52 | if mean_error_ratio == 0: 53 | return last_step * ifactor 54 | if mean_error_ratio < 1: 55 | dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device) 56 | error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step) 57 | exponent = torch.tensor(1 / order).type_as(last_step) 58 | factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor)) 59 | return last_step / factor 60 | 61 | 62 | def _abs_square(x): 63 | return torch.mul(x, x) 64 | 65 | 66 | class Tsit5Solver(AdaptiveStepsizeODESolver): 67 | 68 | def __init__( 69 | self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1, 70 | **unused_kwargs 71 | ): 72 | _handle_unused_kwargs(self, unused_kwargs) 73 | del unused_kwargs 74 | 75 | self.func = func 76 | self.y0 = y0 77 | self.rtol = rtol 78 | self.atol = atol 79 | self.first_step = first_step 80 | self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device) 81 | self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device) 82 | self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device) 83 | self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device) 84 | 85 | def before_integrate(self, t): 86 | if self.first_step is None: 87 | first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol).to(t) 88 | else: 89 | first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device) 90 | self.rk_state = _RungeKuttaState( 91 | self.y0, 92 | self.func(t[0].type_as(self.y0[0]), self.y0), t[0], t[0], first_step, 93 | tuple(map(lambda x: [x] * 7, self.y0)) 94 | ) 95 | 96 | def advance(self, next_t): 97 | """Interpolate through the next time point, integrating as necessary.""" 98 | n_steps = 0 99 | while next_t > self.rk_state.t1: 100 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) 101 | self.rk_state = self._adaptive_tsit5_step(self.rk_state) 102 | n_steps += 1 103 | return _interp_eval_tsit5(self.rk_state.t0, self.rk_state.t1, self.rk_state.interp_coeff, next_t) 104 | 105 | def _adaptive_tsit5_step(self, rk_state): 106 | """Take an adaptive Runge-Kutta step to integrate the ODE.""" 107 | y0, f0, _, t0, dt, _ = rk_state 108 | ######################################################## 109 | # Assertions # 110 | ######################################################## 111 | assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item()) 112 | for y0_ in y0: 113 | assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_) 114 | y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU) 115 | 116 | ######################################################## 117 | # Error Ratio # 118 | ######################################################## 119 | error_tol = tuple(self.atol + self.rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1)) 120 | tensor_error_ratio = tuple(y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol)) 121 | sq_error_ratio = tuple( 122 | torch.mul(tensor_error_ratio_, tensor_error_ratio_) for tensor_error_ratio_ in tensor_error_ratio 123 | ) 124 | mean_error_ratio = ( 125 | sum(torch.sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio) / 126 | sum(sq_error_ratio_.numel() for sq_error_ratio_ in sq_error_ratio) 127 | ) 128 | accept_step = mean_error_ratio <= 1 129 | 130 | ######################################################## 131 | # Update RK State # 132 | ######################################################## 133 | y_next = y1 if accept_step else y0 134 | f_next = f1 if accept_step else f0 135 | t_next = t0 + dt if accept_step else t0 136 | dt_next = _optimal_step_size(dt, mean_error_ratio, self.safety, self.ifactor, self.dfactor) 137 | k_next = k if accept_step else self.rk_state.interp_coeff 138 | rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, k_next) 139 | return rk_state 140 | -------------------------------------------------------------------------------- /ode_nn/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.utils import data 5 | device = torch.device("cpu")#torch.device("cuda" if torch.cuda.is_available() else 6 | 7 | def get_lr(optimizer): 8 | for param_group in optimizer.param_groups: 9 | return param_group['lr'] 10 | 11 | class Dataset(data.Dataset): 12 | def __init__(self, indices, input_length, mid, output_length, direc, entire_target = False, N = 1.0): 13 | self.mid = mid 14 | self.input_length = input_length 15 | self.output_length = output_length 16 | self.direc = direc 17 | self.list_IDs = indices 18 | self.entire_target = entire_target 19 | self.N = N 20 | 21 | def __len__(self): 22 | return len(self.list_IDs) 23 | 24 | def __getitem__(self, index): 25 | ID = self.list_IDs[index] 26 | sample = torch.load(self.direc + str(ID) + ".pt") 27 | x = sample[(self.mid-self.input_length):self.mid]/self.N 28 | if self.entire_target: 29 | y = sample[(self.mid-self.input_length):(self.mid+self.output_length)]/self.N 30 | else: 31 | y = sample[self.mid:(self.mid+self.output_length)]/self.N 32 | return x.float(), y.float() 33 | 34 | 35 | def train_epoch(model, data_loader, optimizer, loss_fun, feed_tgt = False): 36 | preds = [] 37 | trues = [] 38 | mse = [] 39 | for xx, yy in data_loader: 40 | xx, yy = xx.to(device), yy.to(device) 41 | 42 | if feed_tgt: 43 | yy_pred = model(xx, yy.shape[1], yy) 44 | else: 45 | yy_pred = model(xx, yy.shape[1]) 46 | 47 | loss = torch.zeros_like(yy)#loss = 0 48 | for i, q in enumerate([0.25, 0.5, 0.75]): 49 | e = yy_pred[:,:,:,i] - yy 50 | loss += torch.max(q * e, (q - 1) * e) 51 | 52 | loss = loss.mean()#loss = loss_fun(yy_pred, yy) 53 | mse.append(loss.item()) # 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | trues.append(yy.cpu().data.numpy()) 58 | preds.append(yy_pred.cpu().data.numpy()) 59 | 60 | preds = np.concatenate(preds, axis = 0) 61 | trues = np.concatenate(trues, axis = 0) 62 | return preds, trues, np.mean(mse) #np.round(np.sqrt(np.mean(mse)), 5) 63 | 64 | def eval_epoch(model, data_loader, loss_fun, concat_input = False): 65 | preds = [] 66 | trues = [] 67 | mse = [] 68 | with torch.no_grad(): 69 | for xx, yy in data_loader: 70 | xx, yy = xx.to(device), yy.to(device) 71 | yy_pred = model(xx, yy.shape[1]) 72 | loss = torch.zeros_like(yy) 73 | for i, q in enumerate([0.25, 0.5, 0.75]): 74 | e = yy_pred[:,:,:,i] - yy 75 | loss += torch.max(q * e, (q - 1) * e) 76 | loss = loss.mean() 77 | mse.append(loss.item())#loss = loss_fun(yy_pred, yy) 78 | if concat_input: 79 | trues.append(torch.cat([xx, yy], dim = 1).cpu().data.numpy()) 80 | else: 81 | trues.append(yy.cpu().data.numpy()) 82 | preds.append(yy_pred.cpu().data.numpy()) 83 | 84 | preds = np.concatenate(preds, axis = 0) 85 | trues = np.concatenate(trues, axis = 0) 86 | 87 | return preds, trues, np.mean(mse)#np.round(np.sqrt(np.mean(mse)), 5) 88 | 89 | class Dataset_graph(data.Dataset): 90 | def __init__(self, indices, input_length, mid, output_length, direc, entire_target = False, N = 1.0, stack = True): 91 | self.mid = mid 92 | self.input_length = input_length 93 | self.output_length = output_length 94 | self.direc = direc 95 | self.list_IDs = indices 96 | self.entire_target = entire_target 97 | self.N = N 98 | self.stack = stack 99 | 100 | def __len__(self): 101 | return len(self.list_IDs) 102 | 103 | def __getitem__(self, index): 104 | ID = self.list_IDs[index] 105 | sample = torch.load(self.direc + str(ID) + ".pt") 106 | x = sample[:,(self.mid-self.input_length):self.mid]/self.N 107 | if self.entire_target: 108 | y = sample[:,(self.mid-self.input_length):(self.mid+self.output_length)]/self.N 109 | else: 110 | y = sample[:, self.mid:(self.mid+self.output_length)]/self.N 111 | if self.stack: 112 | return x.reshape(x.shape[0], -1).float(), y.float() 113 | return x.float(), y.float() 114 | 115 | def train_epoch_graph(model, data_loader, optimizer, loss_fun, graph): 116 | preds = [] 117 | trues = [] 118 | mse = [] 119 | for xx, yy in data_loader: 120 | xx, yy = xx.to(device), yy.to(device) 121 | loss = 0 122 | yy_pred = model(graph, xx, yy.shape[2]) 123 | loss = loss_fun(yy_pred, yy) 124 | mse.append(loss.item()) 125 | trues.append(yy.cpu().data.numpy()) 126 | preds.append(yy_pred.cpu().data.numpy()) 127 | optimizer.zero_grad() 128 | loss.backward() 129 | optimizer.step() 130 | 131 | preds = np.concatenate(preds, axis = 0) 132 | trues = np.concatenate(trues, axis = 0) 133 | return preds, trues, np.round(np.sqrt(np.mean(mse)), 5) 134 | 135 | def eval_epoch_graph(model, data_loader, loss_fun, graph): 136 | preds = [] 137 | trues = [] 138 | mse = [] 139 | with torch.no_grad(): 140 | for xx, yy in data_loader: 141 | xx, yy = xx.to(device), yy.to(device) 142 | loss = 0 143 | yy_pred = model(graph, xx, yy.shape[2]) 144 | loss = loss_fun(yy_pred, yy) 145 | mse.append(loss.item()) 146 | if yy.shape[1] != 60: 147 | trues.append(torch.cat([xx.reshape(xx.shape[0], xx.shape[1], -1, 3), yy], dim = 2).cpu().data.numpy()) 148 | else: 149 | trues.append(yy.cpu().data.numpy()) 150 | preds.append(yy_pred.cpu().data.numpy()) 151 | 152 | preds = np.concatenate(preds, axis = 0) 153 | trues = np.concatenate(trues, axis = 0) 154 | 155 | return preds, trues, np.round(np.sqrt(np.mean(mse)), 5) --------------------------------------------------------------------------------