├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── data
├── airlines.dat
├── airports.json
├── covid_data
│ ├── X.pkl
│ ├── from_state_to_id.pkl
│ ├── g.pkl
│ ├── y_cases.pkl
│ ├── y_cases_normalized.pkl
│ ├── y_deaths.pkl
│ └── y_deaths_normalized.pkl
├── gp_on_graphs_teaser.png
├── heat_distribution
│ ├── 1d.json
│ ├── 1d.pkl
│ └── 2d.pkl
├── hungary_chicken_pox
│ ├── hungary_chickenpox.csv
│ ├── hungary_county_edges.csv
│ └── nn_dataset.json
├── st99_d00.dbf
├── st99_d00.shp
├── st99_d00.shx
├── us-states.csv
├── wave
│ ├── X.pkl
│ ├── graph.pkl
│ └── y.pkl
└── weather
│ ├── X.pkl
│ ├── g.pkl
│ ├── g_100.pkl
│ ├── weekly
│ ├── X.pkl
│ ├── g.pkl
│ ├── g_100.pkl
│ ├── g_50.pkl
│ └── y_temprature.pkl
│ └── y_temprature.pkl
├── experiments
├── 1d_experiments.py
├── 1d_wave_experiments.ipynb
├── run_chicken_pox.py
└── run_covid_experiments.py
├── graph_kernels
├── __init__.py
├── data_utils.py
├── kernels.py
├── time_kernels.py
├── utils.py
├── utils_opt.py
└── utils_postproc.py
├── requirements.txt
└── setup.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | data/** filter=lfs diff=lfs merge=lfs -text
2 | *.ipynb linguist-detectable=false
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Non-separable Spatio-temporal Graph Kernels via SPDEs
2 |
3 | This repository is the official implementation of the methods in the publication
4 | * Alexander Nikitin, ST John, Arno Solin, and Samuel Kaski (2022). **Non-separable spatio-temporal graph kernels via SPDEs**. In *Proceedings of the 25th International Conference on Artificial Intelligence and Statistics (AISTATS)*. [[arXiv]](https://arxiv.org/abs/2111.08524)
5 |
6 |
7 |
8 |
9 |
10 |
11 | We leverage an explicit link between stochastic partial differential equations (SPDEs) and Gaussian processes on graphs and derive non-separable spatio-temporal graph kernels that capture interaction across space and time. We formulate the graph kernels for the stochastic heat equation and wave equation. We show that by providing novel tools for spatio-temporal GP modelling on graphs, we outperform pre-existing graph kernels in real-world applications that feature diffusion, oscillation, and other complicated interactions.
12 |
13 | ## Use
14 | The repo uses [git-lfs](https://git-lfs.github.com/) to store datasets. To download the data and replace the LFS pointer files with their actual contents, run:
15 | ```bash
16 | git lfs pull
17 | ```
18 |
19 | The code was tested with `python==3.6` and should work for `python>=3.6`.
20 |
21 | To install the required packages, run:
22 | ```bash
23 | pip install -r requirements.txt
24 | pip install -e .
25 | ```
26 |
27 | ## Structure
28 | The repository contains two sets of kernels for time-independent and temporal processes on graphs.
29 | * Time-independent kernels are stored in `graph_kernels/kernels.py`.
30 | * Temporal kernels are stored in `graph_kernels/time_kernels.py`.
31 | * The stochastic heat equation kernel (SHEK) and the stochastic wave equation kernel (SWEK) are implemented in `graph_kernels/time_kernels.py:StochasticHeatEquation` and `graph_kernels/time_kernels.py:StochasticWaveEquationKernel`, respectively.
32 |
33 | ## Experiments
34 | We provide an experimental evaluation of the proposed kernels on several datasets.
35 |
36 | ### Heat Transfer Dataset
37 | #### Interpolation:
38 | ```bash
39 | python experiments/1d_experiments.py --interpolation --dump_directory=$PATH_TO_RESULTS
40 | ```
41 |
42 | #### Extrapolation:
43 | ```bash
44 | python experiments/1d_experiments.py --extrapolation --dump_directory=$PATH_TO_RESULTS
45 | ```
46 |
47 | ### Chickenpox experiments
48 | #### Interpolation (103 + num_test_weeks):
49 | ```bash
50 | python experiments/run_chicken_pox.py --num_test_weeks=2 --interpolation --dump_directory=$PATH_TO_RESULTS
51 | ```
52 |
53 | #### Extrapolation:
54 | ```bash
55 | python experiments/run_chicken_pox.py --num_test_weeks=2 --extrapolation --dump_directory=$PATH_TO_RESULTS
56 | ```
57 |
58 |
59 | ### Covid19 Experiments
60 | #### Interpolation (33 + num_test_weeks):
61 | ```bash
62 | python experiments/run_covid_experiments.py --log_target --no-use_flight_graph \
63 | --no-use_normalized_target --num_test_weeks=2 --interpolation --dump_directory=$PATH_TO_RESULTS
64 | ```
65 |
66 | #### Extrapolation:
67 | ```bash
68 | python experiments/run_covid_experiments.py --log_target --no-use_flight_graph \
69 | --no-use_normalized_target --num_test_weeks=2 --extrapolation --dump_directory=$PATH_TO_RESULTS
70 | ```
71 |
72 |
73 | ### Wave Experiments:
74 | Open the notebook with Jupyter:
75 | ```bash
76 | jupyter notebook experiments/1d_wave_experiments.ipynb
77 | ```
78 |
79 | ## Citation
80 | If you use the code in this repository for your research, please cite the paper as follows:
81 | ```bibtex
82 | @inproceedings{nikitin2022non,
83 | title={Non-separable spatio-temporal graph kernels via SPDEs},
84 | author={Nikitin, Alexander V and John, ST and Solin, Arno and Kaski, Samuel},
85 | booktitle={International Conference on Artificial Intelligence and Statistics},
86 | pages={10640--10660},
87 | year={2022},
88 | organization={PMLR}
89 | }
90 | ```
91 |
92 | ## Contributing
93 | For all correspondence, please contact alexander.nikitin@aalto.fi.
94 |
95 | ## License
96 | This software is provided under the [Apache License 2.0](LICENSE).
97 |
--------------------------------------------------------------------------------
/data/airlines.dat:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:39be1a432e8b04ebc12860c29281c974a9cb52169c82b2456a835d66ab1548a1
3 | size 396896
4 |
--------------------------------------------------------------------------------
/data/airports.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:406e6e0162b6896d3fb3dd17a9d5332ab90c643392e7643d3f66d2e584f74d8a
3 | size 1673771
4 |
--------------------------------------------------------------------------------
/data/covid_data/X.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e792ec7372d61cab61410ab175f9b9940bb00864a905c17b4a11fa2850766cbd
3 | size 65439
4 |
--------------------------------------------------------------------------------
/data/covid_data/from_state_to_id.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:5b351b7f0cfef000d37bc884bf4a2f1d74f293376cf55f66dd79f07fe8bacd57
3 | size 909
4 |
--------------------------------------------------------------------------------
/data/covid_data/g.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:cc975ea9245ec2cac3b305830a4c9e933ce576599d113f51cf55d8a6fb63dcf3
3 | size 3021
4 |
--------------------------------------------------------------------------------
/data/covid_data/y_cases.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0b5f12d595f5dbdeeec0c5a348d099d001d131829d64ab723a95386c721620fa
3 | size 32797
4 |
--------------------------------------------------------------------------------
/data/covid_data/y_cases_normalized.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1670a82e886976f28800d89e0b3861e4f56509d2c9cba30e456ad2014c6343a0
3 | size 32797
4 |
--------------------------------------------------------------------------------
/data/covid_data/y_deaths.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:ee730ed51369519bfbc51fbeb0c4c1b64f6abc2903ca2788d0916efe7d050713
3 | size 32797
4 |
--------------------------------------------------------------------------------
/data/covid_data/y_deaths_normalized.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1670a82e886976f28800d89e0b3861e4f56509d2c9cba30e456ad2014c6343a0
3 | size 32797
4 |
--------------------------------------------------------------------------------
/data/gp_on_graphs_teaser.png:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:da50d90fdff5fa83c6f110a5c9b040c2d74fc09bbcd9bf683fc0f460f68b34c2
3 | size 174792
4 |
--------------------------------------------------------------------------------
/data/heat_distribution/1d.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:feeaf015b7351bd0a2544c65276c82f43480ea087279cb01480ca14302d4ebcf
3 | size 4510280
4 |
--------------------------------------------------------------------------------
/data/heat_distribution/1d.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:381712337d50cecfc4c64a11b545e67a2b8e4b8509f7a09e304e25e429792b07
3 | size 106433
4 |
--------------------------------------------------------------------------------
/data/heat_distribution/2d.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:8864cb27c84ded9e327ab469418edb659e024fab6dc829204b9c83dbb3776472
3 | size 18179330
4 |
--------------------------------------------------------------------------------
/data/hungary_chicken_pox/hungary_chickenpox.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:bd1b802f1059b9b0f0d2c29d6171bc2f302994f81a92d7c7bb25191f76dc4a23
3 | size 35065
4 |
--------------------------------------------------------------------------------
/data/hungary_chicken_pox/hungary_county_edges.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:0061c603b935f3722b7c621e545a30f9129b0c0c619bd4bd30e62a7975c8d609
3 | size 1850
4 |
--------------------------------------------------------------------------------
/data/hungary_chicken_pox/nn_dataset.json:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:60b7ed962c6714c8ff7d661cecae91b649322ae4bed82e634bd0ad4089a3ff40
3 | size 198780
4 |
--------------------------------------------------------------------------------
/data/st99_d00.dbf:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:f391d2e0a9d32ad60716459cf9080b5cd5ac120f07eeff15b947a7b161d89896
3 | size 57411
4 |
--------------------------------------------------------------------------------
/data/st99_d00.shp:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1be585a5335ae51b08f359b9a506798a4f6c36a247c7965d655b31e5e394bd4c
3 | size 2300316
4 |
--------------------------------------------------------------------------------
/data/st99_d00.shx:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:50abf5eba8312c13bf13f72e60ca8aa595e8bf1f390a2b32e5fa93b6f5964f32
3 | size 2284
4 |
--------------------------------------------------------------------------------
/data/us-states.csv:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:eed8fc98a9bbe4ebc17b7544ae85501ab51aeeab1210f6dd3679e3a5deedb9e7
3 | size 978592
4 |
--------------------------------------------------------------------------------
/data/wave/X.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e470f28c4c4d71b9c1eba8de477cc5647eae2c0f84ddd3732a6ecc1ffd4bfafe
3 | size 35711
4 |
--------------------------------------------------------------------------------
/data/wave/graph.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:002561d92508d0ae33dde23558cc3c809ea634dcb5efa5caf42fc8f9707b1852
3 | size 518
4 |
--------------------------------------------------------------------------------
/data/wave/y.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c59a149e4ff739cb86e72827498411bcacdc100b3dec861873b6815f2bb868b9
3 | size 17933
4 |
--------------------------------------------------------------------------------
/data/weather/X.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:9274bf21ef637c636e5507d3ef1efff0d3e1d98a058e064f08dd1a1439a5032e
3 | size 336159
4 |
--------------------------------------------------------------------------------
/data/weather/g.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:797c017691bf3a08f93062a25ec307f3edf621aa2bd3d1db630a010916ecdb16
3 | size 100281
4 |
--------------------------------------------------------------------------------
/data/weather/g_100.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:720789313b8b32929c7f413d1fa5e910b083006c4baad3353a5c5b54f2006699
3 | size 17844
4 |
--------------------------------------------------------------------------------
/data/weather/weekly/X.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:b04cc32817cde5d487a5d0e798e2f7fd36ca5d7169afb3c5b1d389edcc1d7042
3 | size 84159
4 |
--------------------------------------------------------------------------------
/data/weather/weekly/g.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:4bcc125b71a9b87affbd118b7a0077860e8d9861817002df3344f9dea67feeaf
3 | size 8383
4 |
--------------------------------------------------------------------------------
/data/weather/weekly/g_100.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:4bcc125b71a9b87affbd118b7a0077860e8d9861817002df3344f9dea67feeaf
3 | size 8383
4 |
--------------------------------------------------------------------------------
/data/weather/weekly/g_50.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:a624678a52a683d42bab6bb28dfdff8f4acce283978fbab50baf5a91055e1ef9
3 | size 3915
4 |
--------------------------------------------------------------------------------
/data/weather/weekly/y_temprature.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:f0b78d600134701c9fb08b98153cd831764133d6d44e7b03bf01d773b6477704
3 | size 162114
4 |
--------------------------------------------------------------------------------
/data/weather/y_temprature.pkl:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:67502d11f400a9e6e92d542dae790dd7b272b4b6da60f966a574b546067e2bc5
3 | size 650394
4 |
--------------------------------------------------------------------------------
/experiments/1d_experiments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | import copy
5 |
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | import sklearn
10 | import sklearn.metrics
11 |
12 | import gpflow
13 | from gpflow import Parameter
14 | import tensorflow as tf
15 |
16 | from graph_kernels import data_utils
17 | from graph_kernels import time_kernels
18 | from graph_kernels import utils_opt
19 | from graph_kernels import utils
20 |
21 |
def parse_arguments(argv=None):
    """Parse command-line options for the 1D heat-distribution experiment.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse reads ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with:
          * ``interpolation`` (bool): ``True`` for ``--interpolation``,
            ``False`` for ``--extrapolation`` or when neither flag is given.
          * ``dump_directory`` (str): directory where results are written.
    """
    parser = argparse.ArgumentParser(description='Heat distribution over a 1d line.')

    # Both flags write to the same ``interpolation`` destination, so the
    # mutually exclusive group selects the task. The first action's default
    # (False) is the one argparse applies, i.e. extrapolation is the default.
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--interpolation', action='store_true', default=False,
                       help='Evaluate the models on the interpolation task.')
    group.add_argument('--extrapolation', dest='interpolation', action='store_false',
                       help='Evaluate the models on the extrapolation task (default).')

    parser.add_argument('--dump_directory', type=str, help='Path to directory with results.',
                        default="dump_directory")
    return parser.parse_args(argv)
33 |
34 |
# Parse the command line once at import time and expose the options as
# module-level constants used by the rest of the script.
args = parse_arguments()
INTERPOLATION = args.interpolation
DUMP_DIRECTORY = args.dump_directory
# exist_ok=True replaces the separate os.path.exists() check, which could race
# with another process creating the same directory between check and create.
os.makedirs(DUMP_DIRECTORY, exist_ok=True)
40 |
41 |
# Dataset locations, resolved relative to this script's directory so the
# script does not depend on the current working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DATASET_PATH_1d = os.path.join(_SCRIPT_DIR, "../data/heat_distribution/1d.pkl")
# NOTE: despite the "2s" suffix, this path points at the 2D heat dataset.
DATASET_PATH_2s = os.path.join(_SCRIPT_DIR, "../data/heat_distribution/2d.pkl")

N_ITER = 2000  # passed as ``n_iter`` to utils_opt.evaluate_kernel_mcmc
# Each base seed is also used doubled: [23, 42, 82, 100, 46, 84, 164, 200].
_BASE_SEEDS = [23, 42, 82, 100]
RANDOM_SEEDS = _BASE_SEEDS + [2 * seed for seed in _BASE_SEEDS]
NUM_TRAIN = 50  # number of training timestamps
NUM_TEST = 10  # presumably the number of test timestamps — TODO confirm

gpflow.config.set_default_jitter(1e-8)
53 |
54 |
class ConstantArray(gpflow.mean_functions.MeanFunction):
    """Mean function holding one constant per graph node.

    The first input column is treated as a node index. The mean of a row is the
    constant stored for that node.

    Args:
        shape: Shape of the constants tensor (a gpflow ``Parameter`` initialised
            to zeros). The experiments pass the number of graph nodes.
    """

    def __init__(self, shape):
        super().__init__()
        c = tf.zeros(shape)
        self.c = Parameter(c)

    def __call__(self, X):
        """Return the constants selected by ``X[:, 0]`` as an (N, 1) column.

        ``X[:, 0]`` holds node indices stored as floats. They are cast to int32
        before the lookup.
        """
        node_ids = tf.cast(X[:, 0], dtype=tf.int32)
        # Use the dynamic row count. The static X.shape[0] is None inside a
        # tf.function with an unknown batch size, and tf.reshape rejects None.
        return tf.reshape(tf.gather(self.c, node_ids), [tf.shape(X)[0], 1])
66 |
67 |
def convert_dataset_to_nodes(data, mapping=None):
    """Replace the coordinate in each ``[x, t]`` row by its graph node id.

    Args:
        data: Iterable of rows whose first entry is a coordinate and whose
            second entry is a timestamp (the layout built by
            ``extract_ml_dataset`` for scalar points).
        mapping: Dict from coordinate to node id. Defaults to the module-level
            ``from_x_to_nodes`` built from the experiment graph.

    Returns:
        np.ndarray of rows ``[node_id, t]``.

    Raises:
        KeyError: if a coordinate is not a key of ``mapping``.
    """
    if mapping is None:
        mapping = from_x_to_nodes
    return np.array([[mapping[row[0]], row[1]] for row in data])
73 |
74 |
def extract_ml_dataset(dataset_1d, times):
    """Flatten the snapshots at ``times`` into a supervised (features, target) pair.

    Each point ``x`` of the snapshot at timestamp ``t`` becomes one feature row
    (``x`` with ``t`` appended). The ``y`` entry at the same index becomes the
    matching target value.

    Args:
        dataset_1d: Mapping from timestamp to a dict with index-aligned
            sequences ``"x"`` (points) and ``"y"`` (values).
        times: Timestamps to include, in output order.

    Returns:
        Tuple ``(data, target)``: ``data`` stacks the feature rows, and
        ``target`` is a column array of shape ``(n_rows, 1)``.
    """
    rows = []
    labels = []
    for timestamp in times:
        snapshot = dataset_1d[timestamp]
        for idx, point in enumerate(snapshot["x"]):
            rows.append(np.append(point, np.array(timestamp)))
            labels.append(snapshot["y"][idx])
    features = np.array(rows)
    column = np.array(labels)[:, np.newaxis]
    return features, column
85 |
86 |
def evaluate_kernel(kernel, kernel_name):
    """Evaluate ``kernel`` on the pre-built split of every seed in ``RANDOM_SEEDS``.

    For each seed:
    - ``utils.set_all_random_seeds`` is called and the split ``datasets[seed]``
      is fetched.
    - For every kernel except ``"td_exponential"``, the coordinate column is
      mapped to node ids and a per-node ``ConstantArray`` mean is used.
    - ``"td_exponential"`` keeps the raw coordinates and uses a scalar
      ``gpflow.mean_functions.Constant`` mean.

    Each seed receives fresh deep copies of the kernel and the mean function.

    Args:
        kernel: Kernel instance. It is deep-copied per seed, so the caller's
            object is never the one being optimised.
        kernel_name: Key from ``kernels``. It selects the input encoding
            described above.

    Returns:
        Tuple ``(results, gprocess)``:
        - ``results`` maps each seed to the result returned by
          ``utils_opt.evaluate_kernel_mcmc``.
        - ``gprocess`` is the second value it returned for the last seed, or
          ``None`` when ``RANDOM_SEEDS`` is empty.
    """
    results = {}
    # Pre-bind so an empty seed list returns None instead of raising
    # UnboundLocalError at the return statement.
    gprocess = None
    for random_seed in tqdm(RANDOM_SEEDS):
        utils.set_all_random_seeds(random_seed)

        train_data, train_target, test_data, test_target = datasets[random_seed]
        if kernel_name != "td_exponential":
            train_data = convert_dataset_to_nodes(train_data)
            test_data = convert_dataset_to_nodes(test_data)
            mean_function = ConstantArray(num_nodes)
        else:
            mean_function = gpflow.mean_functions.Constant()
        print("Shape: ", train_data.shape, test_data.shape)
        result, gprocess = utils_opt.evaluate_kernel_mcmc(
            copy.deepcopy(kernel), train_data, train_target,
            test_data, test_target, graph, mean_function=copy.deepcopy(mean_function),
            n_iter=N_ITER, optimizer_name="LBFGS")

        results[random_seed] = result
    return results, gprocess
107 |
108 |
# Load the 1d heat-distribution snapshots. The graph is built from the points
# of the first stored snapshot.
dataset_1d = data_utils.read_heat_1d(DATASET_PATH_1d)

t = list(dataset_1d.keys())[0]
graph = data_utils.build_graph_from_1d_points(dataset_1d[t]["x"])
num_nodes = len(graph.nodes())

# Coordinate -> node id lookup used by convert_dataset_to_nodes.
from_x_to_nodes = {attrs["point"]: node for node, attrs in graph.nodes(data=True)}

# All timestamps except the earliest one are used for the train/test windows.
times = sorted(dataset_1d.keys())[1:]

# One split per seed. The i-th seed trains on NUM_TRAIN timestamps starting at
# offset i and tests on the NUM_TEST timestamps that follow (extrapolation).
# With --interpolation, both windows are pooled and 10% of the rows are held
# out at random instead.
datasets = {}
for offset, seed in enumerate(RANDOM_SEEDS):
    window_end = offset + NUM_TRAIN
    train_times = times[offset:window_end]
    test_times = times[window_end:window_end + NUM_TEST]

    train_data, train_target = extract_ml_dataset(dataset_1d, train_times)
    test_data, test_target = extract_ml_dataset(dataset_1d, test_times)
    if INTERPOLATION:
        pooled_data = np.concatenate((train_data, test_data))
        pooled_target = np.concatenate((train_target, test_target))
        train_data, test_data, train_target, test_target = \
            sklearn.model_selection.train_test_split(
                pooled_data, pooled_target, test_size=0.1, random_state=seed)

    datasets[seed] = (train_data, train_target, test_data, test_target)
136 |
137 |
# Kernels to compare. Each key is also the name of its result folder, and the
# insertion order is the evaluation order.
kernels = {
    "td_exponential": time_kernels.TimeDistributed1dExponentialKernel(graph),
    # Disabled: "td_laplacian": time_kernels.TimeDistributedLaplacianKernel(graph),

    # Time-distributed Matern kernels with kappa=1 and nu in {5/2, 3/2, 1/2}.
    "td_matern_nu_52_d_1": time_kernels.TimeDistributedMaternKernel(
        graph, nu=5 / 2, kappa=1),
    "td_matern_nu_32_d_1": time_kernels.TimeDistributedMaternKernel(
        graph, nu=3 / 2, kappa=1),
    "td_matern_nu_12_d_1": time_kernels.TimeDistributedMaternKernel(
        graph, nu=1 / 2, kappa=1),

    # Stochastic heat equation kernels, each given its own list holding one
    # variance of 1.0 per graph node.
    "stoch_heat_vector_pseudo_diff_1": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=5 / 2, kappa=1,
        variance=[1.0] * num_nodes),
    "stoch_heat_vector_pseudo_diff_2": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=3 / 2, kappa=1,
        variance=[1.0] * num_nodes),
    "stoch_heat_vector_pseudo_diff_3": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=1 / 2, kappa=1,
        variance=[1.0] * num_nodes),
}
154 |
# Evaluate every kernel and write its per-seed results to
# <DUMP_DIRECTORY>/<kernel_name>/result.json.
for kernel_name, kernel in kernels.items():
    print("Evaluating {}".format(kernel_name))
    result, gprocess = evaluate_kernel(
        copy.deepcopy(kernel), kernel_name)
    folder = os.path.join(DUMP_DIRECTORY, kernel_name)
    os.makedirs(folder, exist_ok=True)
    # Persisting gprocess is disabled (pickle is not imported by this script):
    # pickle.dump(gprocess, open(os.path.join(folder, "gprocess.pkl"), "wb"))
    # Close and flush the file deterministically instead of leaving the handle
    # to the garbage collector. Integer seed keys are written as JSON strings.
    with open(os.path.join(folder, "result.json"), "w") as result_file:
        json.dump(result, result_file)
163 |
--------------------------------------------------------------------------------
/experiments/1d_wave_experiments.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "data": {
10 | "text/plain": [
11 | ""
12 | ]
13 | },
14 | "metadata": {},
15 | "output_type": "display_data"
16 | }
17 | ],
18 | "source": [
19 | "%load_ext autoreload\n",
20 | "%autoreload 2\n",
21 | "\n",
22 | "from IPython.display import HTML\n",
23 | "\n",
24 | "import pickle\n",
25 | "\n",
26 | "import seaborn as sns\n",
27 | "import collections\n",
28 | "import networkx as nx\n",
29 | "import copy\n",
30 | "import time\n",
31 | "import os\n",
32 | "\n",
33 | "import numpy as np\n",
34 | "import matplotlib.pyplot as plt\n",
35 | "\n",
36 | "from matplotlib.pyplot import figure\n",
37 | "\n",
38 | "import gpflow\n",
39 | "\n",
40 | "from graph_kernels import data_utils\n",
41 | "from graph_kernels import utils\n",
42 | "from graph_kernels import time_kernels\n",
43 | "from graph_kernels import utils_opt\n",
44 | "from graph_kernels import utils_postproc\n",
45 | "\n",
46 | "figure(num=None, figsize=(28, 28), dpi=80, facecolor='w', edgecolor='k')\n",
47 | "\n",
48 | "gpflow.config.set_default_jitter(1e-4)"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 13,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "PATH_TO_DATA_FOLDER = (\"../data/wave/\")\n",
58 | "PATH_X = os.path.join(PATH_TO_DATA_FOLDER, \"X.pkl\")\n",
59 | "PATH_Y = os.path.join(PATH_TO_DATA_FOLDER, \"y.pkl\")\n",
60 | "PATH_GRAPH = os.path.join(PATH_TO_DATA_FOLDER, \"graph.pkl\")\n",
61 | "graph = pickle.load(open(PATH_GRAPH, \"rb\"))\n",
62 | "N_NODES = len(graph.nodes())\n",
63 | "\n",
64 | "DUMP_DIRECTORY = \"dump_directory\"\n",
65 | "os.makedirs(DUMP_DIRECTORY, exist_ok=True)\n",
66 | "os.makedirs(\"images\", exist_ok=True)"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": 3,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "NUM_TEST_WEEKS = 2\n",
76 | "NUM_TRAIN = 50 * N_NODES\n",
77 | "NUM_TEST = NUM_TEST_WEEKS * N_NODES\n",
78 | "START = 4 * N_NODES * 2\n",
79 | "N_ITER = 5_000\n",
80 | "RANDOM_SEEDS = [23, 42, 82, 100, 123, 223,\n",
81 | " 2 * 23, 2 * 42, 2 * 82, 2 * 100, 2 * 123, 2 * 223]"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 4,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "X = pickle.load(open(PATH_X, \"rb\"))\n",
91 | "y = pickle.load(open(PATH_Y, \"rb\"))"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 5,
97 | "metadata": {},
98 | "outputs": [],
99 | "source": [
100 | "kernels = {\n",
101 | " \"td_laplacian\": time_kernels.TimeDistributedLaplacianKernel(graph),\n",
102 | " \"td_matern_nu_52_k_1\": time_kernels.TimeDistributedMaternKernel(graph, nu=5/2, kappa=1),\n",
103 | " \"td_matern_nu_32_k_1\": time_kernels.TimeDistributedMaternKernel(graph, nu=3/2, kappa=1),\n",
104 | " \"td_matern_nu_12_k_1\": time_kernels.TimeDistributedMaternKernel(graph, nu=1/2, kappa=1),\n",
105 | " \"stoch_heat_vector_pseudo_diff_1\": time_kernels.StochasticHeatEquation(graph,\n",
106 | " c=0.1, use_pseudodifferential=True, nu=5/2,\n",
107 | " kappa=1, variance=[1.]*len(graph.nodes())),\n",
108 | " \"stoch_heat_vector_pseudo_diff_2\": time_kernels.StochasticHeatEquation(graph,\n",
109 | " c=0.1, use_pseudodifferential=True, nu=3/2,\n",
110 | " kappa=1, variance=[1.]*len(graph.nodes())),\n",
111 | " \"stoch_heat_vector_pseudo_diff_3\": time_kernels.StochasticHeatEquation(graph,\n",
112 | " c=0.1, use_pseudodifferential=True, nu=1/2,\n",
113 | " kappa=1, variance=1.),\n",
114 | " \"stoch_wave_kernel_nu_12\": time_kernels.StochasticWaveEquationKernel(\n",
115 | " graph, c=0.1, use_pseudodifferential=True, nu=1 / 2, kappa=10,\n",
116 | " variance=1.0),\n",
117 | " \"stoch_wave_kernel_nu_32\": time_kernels.StochasticWaveEquationKernel(\n",
118 | " graph, c=0.1, use_pseudodifferential=True, nu=3 / 2, kappa=1,\n",
119 | " variance=1.0),\n",
120 | " \"stoch_wave_kernel_nu_52\": time_kernels.StochasticWaveEquationKernel(\n",
121 | " graph, c=0.1, use_pseudodifferential=True, nu=5 / 2, kappa=10,\n",
122 | " variance=1.0),\n",
123 | "}"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 6,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "results = collections.defaultdict(dict)"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 7,
138 | "metadata": {},
139 | "outputs": [
140 | {
141 | "name": "stdout",
142 | "output_type": "stream",
143 | "text": [
144 | "0:\tELBO: -522.78695\tMAPE: 49411083745101.2500000000\tMAE: 0.0494760289\n",
145 | "1:\tELBO: -649.60072\tMAPE: 87218169830742.5468750000\tMAE: 0.0515282197\n",
146 | "2:\tELBO: -824.59367\tMAPE: 90807941555610.4375000000\tMAE: 0.0511540748\n",
147 | "3:\tELBO: -955.31851\tMAPE: 68342742878235.9453125000\tMAE: 0.0404082756\n",
148 | "4:\tELBO: -1152.45533\tMAPE: 35426937814843.6562500000\tMAE: 0.0236137586\n",
149 | "5:\tELBO: -1193.72487\tMAPE: 41640826968279.1250000000\tMAE: 0.0249512851\n",
150 | "6:\tELBO: -1221.01163\tMAPE: 47917634274505.5546875000\tMAE: 0.0315043598\n",
151 | "7:\tELBO: -1256.06121\tMAPE: 40104203824625.6015625000\tMAE: 0.0250922668\n",
152 | "8:\tELBO: -1260.73102\tMAPE: 39114375409547.6875000000\tMAE: 0.0243952481\n",
153 | "9:\tELBO: -1266.31820\tMAPE: 38759128574627.7812500000\tMAE: 0.0245500426\n",
154 | "10:\tELBO: -1271.73075\tMAPE: 37424899169070.4375000000\tMAE: 0.0247418769\n",
155 | "11:\tELBO: -1276.20968\tMAPE: 37220355215276.4453125000\tMAE: 0.0242883255\n",
156 | "12:\tELBO: -1283.45992\tMAPE: 36377933381502.3750000000\tMAE: 0.0246828267\n",
157 | "13:\tELBO: -1288.13708\tMAPE: 36276255408576.1562500000\tMAE: 0.0249266317\n",
158 | "14:\tELBO: -1289.39792\tMAPE: 37828748534609.7421875000\tMAE: 0.0252464624\n",
159 | "15:\tELBO: -1289.72658\tMAPE: 37351560268436.4765625000\tMAE: 0.0250686593\n",
160 | "16:\tELBO: -1289.78160\tMAPE: 37590105937283.8828125000\tMAE: 0.0250816727\n",
161 | "17:\tELBO: -1289.81950\tMAPE: 37718362948373.6171875000\tMAE: 0.0250789229\n",
162 | "18:\tELBO: -1289.91168\tMAPE: 37912337725456.7109375000\tMAE: 0.0250336248\n",
163 | "19:\tELBO: -1290.00858\tMAPE: 37990765436280.9062500000\tMAE: 0.0250376842\n",
164 | "20:\tELBO: -1290.04343\tMAPE: 37811904701063.3671875000\tMAE: 0.0248566383\n",
165 | "21:\tELBO: -1290.14220\tMAPE: 37726453581146.7812500000\tMAE: 0.0249401615\n",
166 | "22:\tELBO: -1290.22413\tMAPE: 37456102859620.1328125000\tMAE: 0.0249979521\n",
167 | "23:\tELBO: -1290.26739\tMAPE: 37203011272484.7031250000\tMAE: 0.0250672947\n",
168 | "24:\tELBO: -1290.28505\tMAPE: 37021202556996.2734375000\tMAE: 0.0250241352\n",
169 | "25:\tELBO: -1290.31220\tMAPE: 36875243251738.9921875000\tMAE: 0.0250548885\n",
170 | "26:\tELBO: -1290.33380\tMAPE: 36670146296457.4687500000\tMAE: 0.0250621257\n",
171 | "27:\tELBO: -1290.33825\tMAPE: 36638160318912.6953125000\tMAE: 0.0251000477\n",
172 | "28:\tELBO: -1290.34270\tMAPE: 36533621588011.8671875000\tMAE: 0.0250764243\n",
173 | "29:\tELBO: -1290.34484\tMAPE: 36458184152902.2968750000\tMAE: 0.0250685636\n",
174 | "30:\tELBO: -1290.34591\tMAPE: 36417236830188.9843750000\tMAE: 0.0250725603\n",
175 | "31:\tELBO: -1290.34686\tMAPE: 36405260022223.0078125000\tMAE: 0.0250828099\n",
176 | "32:\tELBO: -1290.34896\tMAPE: 36391605242716.1328125000\tMAE: 0.0251038633\n",
177 | "33:\tELBO: -1290.35328\tMAPE: 36386873986147.7343750000\tMAE: 0.0251220274\n",
178 | "34:\tELBO: -1290.36344\tMAPE: 36396310676440.8515625000\tMAE: 0.0251853621\n",
179 | "35:\tELBO: -1290.37104\tMAPE: 36215908826574.3515625000\tMAE: 0.0251737857\n",
180 | "36:\tELBO: -1290.38026\tMAPE: 36275447139535.4062500000\tMAE: 0.0253036889\n",
181 | "37:\tELBO: -1290.38255\tMAPE: 36275886204780.6015625000\tMAE: 0.0252891037\n",
182 | "38:\tELBO: -1290.38327\tMAPE: 36260072099426.7500000000\tMAE: 0.0252777955\n",
183 | "39:\tELBO: -1290.38367\tMAPE: 36258613944832.1015625000\tMAE: 0.0252740860\n",
184 | "40:\tELBO: -1290.38400\tMAPE: 36229128932267.3359375000\tMAE: 0.0252708193\n",
185 | "41:\tELBO: -1290.38459\tMAPE: 36234916872106.7187500000\tMAE: 0.0252716454\n",
186 | "42:\tELBO: -1290.38487\tMAPE: 36233358363031.0937500000\tMAE: 0.0252738838\n",
187 | "43:\tELBO: -1290.38496\tMAPE: 36232455782840.6718750000\tMAE: 0.0252780025\n",
188 | "44:\tELBO: -1290.38506\tMAPE: 36231895072822.7343750000\tMAE: 0.0252793214\n",
189 | "45:\tELBO: -1290.38515\tMAPE: 36231788775926.3593750000\tMAE: 0.0252803320\n",
190 | "46:\tELBO: -1290.38525\tMAPE: 36233244305607.2265625000\tMAE: 0.0252819269\n",
191 | "47:\tELBO: -1290.38535\tMAPE: 36239829064610.5390625000\tMAE: 0.0252764902\n",
192 | "48:\tELBO: -1290.38547\tMAPE: 36259148400086.1250000000\tMAE: 0.0252774194\n",
193 | "49:\tELBO: -1290.38555\tMAPE: 36278348998613.0468750000\tMAE: 0.0252769212\n",
194 | "50:\tELBO: -1290.38560\tMAPE: 36298055148564.8750000000\tMAE: 0.0252764348\n",
195 | "51:\tELBO: -1290.38560\tMAPE: 36302075428952.1171875000\tMAE: 0.0252772613\n",
196 | "52:\tELBO: -1290.38560\tMAPE: 36300656613550.1796875000\tMAE: 0.0252769210\n",
197 | "0:\tELBO: -521.07341\tMAPE: 53051592797944.7734375000\tMAE: 0.0403417802\n",
198 | "1:\tELBO: -637.39032\tMAPE: 85838123329290.0000000000\tMAE: 0.0499818521\n",
199 | "2:\tELBO: -794.32688\tMAPE: 80315799447217.0468750000\tMAE: 0.0455918271\n",
200 | "3:\tELBO: -950.49424\tMAPE: 50455065962070.2812500000\tMAE: 0.0335602561\n",
201 | "4:\tELBO: -1162.52089\tMAPE: 25293899769938.0468750000\tMAE: 0.0186312915\n",
202 | "5:\tELBO: -1203.47081\tMAPE: 28925213382676.9531250000\tMAE: 0.0188311960\n",
203 | "6:\tELBO: -1225.72753\tMAPE: 36400418601630.1484375000\tMAE: 0.0249787542\n",
204 | "7:\tELBO: -1252.80187\tMAPE: 30863340026413.0234375000\tMAE: 0.0191225974\n",
205 | "8:\tELBO: -1256.92322\tMAPE: 29207991150778.0468750000\tMAE: 0.0185416853\n",
206 | "9:\tELBO: -1261.48635\tMAPE: 27999807249039.6992187500\tMAE: 0.0185146359\n",
207 | "10:\tELBO: -1267.30300\tMAPE: 25747644833444.5859375000\tMAE: 0.0190772265\n",
208 | "11:\tELBO: -1272.36907\tMAPE: 24759491749926.3632812500\tMAE: 0.0195360430\n",
209 | "12:\tELBO: -1275.34238\tMAPE: 23915682410714.3437500000\tMAE: 0.0190782279\n",
210 | "13:\tELBO: -1279.97129\tMAPE: 23455410514548.2578125000\tMAE: 0.0187608310\n",
211 | "14:\tELBO: -1282.59399\tMAPE: 24517625070189.4218750000\tMAE: 0.0190111756\n",
212 | "15:\tELBO: -1283.45947\tMAPE: 24744791104060.4492187500\tMAE: 0.0191856242\n",
213 | "16:\tELBO: -1284.06306\tMAPE: 25998465817491.7773437500\tMAE: 0.0196145847\n",
214 | "17:\tELBO: -1284.23570\tMAPE: 25960147107733.9335937500\tMAE: 0.0195656276\n",
215 | "18:\tELBO: -1284.79248\tMAPE: 25813699279330.3710937500\tMAE: 0.0194245320\n",
216 | "19:\tELBO: -1284.83396\tMAPE: 25595705488161.6875000000\tMAE: 0.0193885085\n",
217 | "20:\tELBO: -1285.02767\tMAPE: 25687822847533.1718750000\tMAE: 0.0193525897\n",
218 | "21:\tELBO: -1285.14457\tMAPE: 25727419718968.5507812500\tMAE: 0.0194988424\n",
219 | "22:\tELBO: -1285.17864\tMAPE: 25656719008983.5742187500\tMAE: 0.0193896804\n",
220 | "23:\tELBO: -1285.22124\tMAPE: 25596549294383.5664062500\tMAE: 0.0193997784\n",
221 | "24:\tELBO: -1285.26066\tMAPE: 25564218477120.2578125000\tMAE: 0.0195108118\n",
222 | "25:\tELBO: -1285.28841\tMAPE: 25373618391193.3476562500\tMAE: 0.0193913635\n",
223 | "26:\tELBO: -1285.30969\tMAPE: 25327002426793.8242187500\tMAE: 0.0194376667\n",
224 | "27:\tELBO: -1285.34567\tMAPE: 25114742300525.7500000000\tMAE: 0.0195384952\n",
225 | "28:\tELBO: -1285.36203\tMAPE: 24770633615381.0898437500\tMAE: 0.0196209582\n",
226 | "29:\tELBO: -1285.36871\tMAPE: 24493440832749.2382812500\tMAE: 0.0196704521\n",
227 | "30:\tELBO: -1285.37270\tMAPE: 24555196542566.5781250000\tMAE: 0.0196443852\n",
228 | "31:\tELBO: -1285.37364\tMAPE: 24541241796232.8164062500\tMAE: 0.0196413299\n",
229 | "32:\tELBO: -1285.37446\tMAPE: 24499532180577.1523437500\tMAE: 0.0196488690\n",
230 | "33:\tELBO: -1285.37472\tMAPE: 24473287943903.0781250000\tMAE: 0.0196586874\n",
231 | "34:\tELBO: -1285.37478\tMAPE: 24463546111732.8710937500\tMAE: 0.0196669098\n",
232 | "35:\tELBO: -1285.37481\tMAPE: 24466785835111.2031250000\tMAE: 0.0196701347\n",
233 | "36:\tELBO: -1285.37483\tMAPE: 24470562039468.4062500000\tMAE: 0.0196711721\n",
234 | "37:\tELBO: -1285.37493\tMAPE: 24487427852971.9804687500\tMAE: 0.0196762357\n",
235 | "38:\tELBO: -1285.37502\tMAPE: 24492320119836.5390625000\tMAE: 0.0196809189\n",
236 | "39:\tELBO: -1285.37508\tMAPE: 24514239133188.3281250000\tMAE: 0.0196852435\n",
237 | "40:\tELBO: -1285.37514\tMAPE: 24503984547086.8398437500\tMAE: 0.0196857878\n",
238 | "41:\tELBO: -1285.37518\tMAPE: 24493889456120.7265625000\tMAE: 0.0196869308\n",
239 | "42:\tELBO: -1285.37520\tMAPE: 24493668066128.2031250000\tMAE: 0.0196870576\n",
240 | "43:\tELBO: -1285.37521\tMAPE: 24492942045649.6679687500\tMAE: 0.0196930923\n",
241 | "44:\tELBO: -1285.37522\tMAPE: 24498887437306.1132812500\tMAE: 0.0196895703\n",
242 | "45:\tELBO: -1285.37522\tMAPE: 24500649112213.1406250000\tMAE: 0.0196895311\n",
243 | "0:\tELBO: -519.59184\tMAPE: 54638454531394.9531250000\tMAE: 0.0333110624\n",
244 | "1:\tELBO: -631.99642\tMAPE: 82680744880616.5468750000\tMAE: 0.0437322019\n",
245 | "2:\tELBO: -781.40598\tMAPE: 67294044918960.8828125000\tMAE: 0.0410882433\n",
246 | "3:\tELBO: -948.78567\tMAPE: 31152417693185.1484375000\tMAE: 0.0234719582\n",
247 | "4:\tELBO: -1148.92590\tMAPE: 15449952606350.5019531250\tMAE: 0.0177633879\n",
248 | "5:\tELBO: -1202.11183\tMAPE: 17389284108381.2167968750\tMAE: 0.0120175991\n",
249 | "6:\tELBO: -1223.90298\tMAPE: 20950906160427.4218750000\tMAE: 0.0140121172\n",
250 | "7:\tELBO: -1241.35473\tMAPE: 22613655741948.3593750000\tMAE: 0.0155389304\n",
251 | "8:\tELBO: -1254.77144\tMAPE: 18443910140354.6093750000\tMAE: 0.0131489969\n",
252 | "9:\tELBO: -1257.80750\tMAPE: 18248126347920.6679687500\tMAE: 0.0130912943\n",
253 | "10:\tELBO: -1262.56495\tMAPE: 17969626087017.8203125000\tMAE: 0.0145590158\n",
254 | "11:\tELBO: -1266.18126\tMAPE: 15803009057356.4453125000\tMAE: 0.0144326908\n",
255 | "12:\tELBO: -1269.89810\tMAPE: 14535597325946.2011718750\tMAE: 0.0147423155\n",
256 | "13:\tELBO: -1274.43604\tMAPE: 14151515975149.9941406250\tMAE: 0.0147181544\n",
257 | "14:\tELBO: -1278.51541\tMAPE: 13183391196713.0312500000\tMAE: 0.0133701219\n",
258 | "15:\tELBO: -1279.10665\tMAPE: 17058990516505.0175781250\tMAE: 0.0132238684\n"
259 | ]
260 | },
261 | {
262 | "name": "stdout",
263 | "output_type": "stream",
264 | "text": [
265 | "16:\tELBO: -1280.36431\tMAPE: 15546901558672.8574218750\tMAE: 0.0129894903\n",
266 | "17:\tELBO: -1280.44238\tMAPE: 15385211309965.7792968750\tMAE: 0.0129066530\n",
267 | "18:\tELBO: -1280.50935\tMAPE: 15475481464456.0136718750\tMAE: 0.0128813178\n",
268 | "19:\tELBO: -1280.61932\tMAPE: 15571246280553.9003906250\tMAE: 0.0128116411\n",
269 | "20:\tELBO: -1280.78383\tMAPE: 15676141405980.3632812500\tMAE: 0.0127725975\n",
270 | "21:\tELBO: -1280.84693\tMAPE: 15608953450839.4980468750\tMAE: 0.0129810770\n",
271 | "22:\tELBO: -1281.08525\tMAPE: 15378824660799.0917968750\tMAE: 0.0128606970\n",
272 | "23:\tELBO: -1281.24122\tMAPE: 15253448055387.7890625000\tMAE: 0.0128130038\n",
273 | "24:\tELBO: -1281.29833\tMAPE: 14957941209521.7167968750\tMAE: 0.0127707116\n",
274 | "25:\tELBO: -1281.32528\tMAPE: 14757430140586.4433593750\tMAE: 0.0127867230\n",
275 | "26:\tELBO: -1281.34251\tMAPE: 14787518812107.8808593750\tMAE: 0.0128147471\n",
276 | "27:\tELBO: -1281.36941\tMAPE: 14432609089545.5058593750\tMAE: 0.0127594236\n",
277 | "28:\tELBO: -1281.37592\tMAPE: 14284464106198.2109375000\tMAE: 0.0127099218\n",
278 | "29:\tELBO: -1281.37761\tMAPE: 14226667785677.9179687500\tMAE: 0.0126935910\n",
279 | "30:\tELBO: -1281.37925\tMAPE: 14176556288156.0488281250\tMAE: 0.0126853638\n",
280 | "31:\tELBO: -1281.38072\tMAPE: 14133894926985.5800781250\tMAE: 0.0126895069\n",
281 | "32:\tELBO: -1281.38451\tMAPE: 14090043135888.5117187500\tMAE: 0.0127092474\n",
282 | "33:\tELBO: -1281.38911\tMAPE: 13960413741569.9121093750\tMAE: 0.0127147823\n",
283 | "34:\tELBO: -1281.39658\tMAPE: 14039474139445.9941406250\tMAE: 0.0127784276\n",
284 | "35:\tELBO: -1281.40337\tMAPE: 14183619239077.5136718750\tMAE: 0.0128344683\n",
285 | "36:\tELBO: -1281.40694\tMAPE: 14316273947017.6054687500\tMAE: 0.0128495543\n",
286 | "37:\tELBO: -1281.40755\tMAPE: 14345385047397.7128906250\tMAE: 0.0128606281\n",
287 | "38:\tELBO: -1281.40839\tMAPE: 14371497543167.8867187500\tMAE: 0.0128523758\n",
288 | "39:\tELBO: -1281.40910\tMAPE: 14391751515555.0058593750\tMAE: 0.0128454590\n",
289 | "40:\tELBO: -1281.40967\tMAPE: 14392642780911.1933593750\tMAE: 0.0128410146\n",
290 | "41:\tELBO: -1281.41044\tMAPE: 14372597508016.8144531250\tMAE: 0.0128387525\n",
291 | "42:\tELBO: -1281.41104\tMAPE: 14349582304957.5058593750\tMAE: 0.0128509165\n",
292 | "43:\tELBO: -1281.41145\tMAPE: 14331809286883.5546875000\tMAE: 0.0128558697\n",
293 | "44:\tELBO: -1281.41161\tMAPE: 14315892796275.5546875000\tMAE: 0.0128638189\n",
294 | "45:\tELBO: -1281.41164\tMAPE: 14313428216734.7832031250\tMAE: 0.0128660344\n",
295 | "46:\tELBO: -1281.41164\tMAPE: 14320820490553.5996093750\tMAE: 0.0128698394\n",
296 | "47:\tELBO: -1281.41165\tMAPE: 14320261215561.0136718750\tMAE: 0.0128684057\n",
297 | "48:\tELBO: -1281.41165\tMAPE: 14320497270011.2441406250\tMAE: 0.0128679058\n",
298 | "49:\tELBO: -1281.41165\tMAPE: 14321674477509.0996093750\tMAE: 0.0128687225\n",
299 | "50:\tELBO: -1281.41166\tMAPE: 14320283412827.4941406250\tMAE: 0.0128677828\n",
300 | "51:\tELBO: -1281.41166\tMAPE: 14319639588014.8105468750\tMAE: 0.0128679143\n",
301 | "0:\tELBO: -519.10889\tMAPE: 52401036505250.9531250000\tMAE: 0.0310284024\n",
302 | "1:\tELBO: -633.20306\tMAPE: 75594345700699.0937500000\tMAE: 0.0517868061\n",
303 | "2:\tELBO: -781.60368\tMAPE: 49765990397264.1328125000\tMAE: 0.0442772133\n",
304 | "3:\tELBO: -948.99524\tMAPE: 12772076559101.4179687500\tMAE: 0.0244733302\n",
305 | "4:\tELBO: -1118.76881\tMAPE: 8871209368007.8886718750\tMAE: 0.0255990605\n",
306 | "5:\tELBO: -1192.57630\tMAPE: 9332732155548.2871093750\tMAE: 0.0171516561\n",
307 | "6:\tELBO: -1216.85734\tMAPE: 13489316548852.2265625000\tMAE: 0.0180606670\n",
308 | "7:\tELBO: -1247.09159\tMAPE: 14029028506229.9882812500\tMAE: 0.0183542029\n",
309 | "8:\tELBO: -1250.46601\tMAPE: 8548422546060.6787109375\tMAE: 0.0154875537\n",
310 | "9:\tELBO: -1254.93984\tMAPE: 9917052701377.8808593750\tMAE: 0.0167717834\n",
311 | "10:\tELBO: -1256.78971\tMAPE: 10606909624116.5117187500\tMAE: 0.0167994990\n",
312 | "11:\tELBO: -1260.05819\tMAPE: 10113472926837.5644531250\tMAE: 0.0169781542\n",
313 | "12:\tELBO: -1269.55030\tMAPE: 7774202734347.0976562500\tMAE: 0.0169858655\n",
314 | "13:\tELBO: -1275.53195\tMAPE: 7594326910074.1279296875\tMAE: 0.0165610521\n",
315 | "14:\tELBO: -1277.78128\tMAPE: 8214522174597.3056640625\tMAE: 0.0155452604\n",
316 | "15:\tELBO: -1277.91885\tMAPE: 7846963247494.2021484375\tMAE: 0.0152263034\n",
317 | "16:\tELBO: -1278.00928\tMAPE: 8118168043055.8271484375\tMAE: 0.0151699122\n",
318 | "17:\tELBO: -1278.04811\tMAPE: 8188392849514.3623046875\tMAE: 0.0151368852\n",
319 | "18:\tELBO: -1278.23309\tMAPE: 8417542385146.2900390625\tMAE: 0.0150634879\n",
320 | "19:\tELBO: -1278.44332\tMAPE: 8475261418360.4814453125\tMAE: 0.0150288240\n",
321 | "20:\tELBO: -1278.46029\tMAPE: 8558768456729.9277343750\tMAE: 0.0151721818\n",
322 | "21:\tELBO: -1278.64913\tMAPE: 8370573515239.9521484375\tMAE: 0.0151737020\n",
323 | "22:\tELBO: -1278.72524\tMAPE: 8105587130185.7587890625\tMAE: 0.0152331514\n",
324 | "23:\tELBO: -1278.74470\tMAPE: 7808122098852.6679687500\tMAE: 0.0152145013\n",
325 | "24:\tELBO: -1278.76687\tMAPE: 7802903571228.6884765625\tMAE: 0.0152452462\n",
326 | "25:\tELBO: -1278.79372\tMAPE: 7784845550469.1191406250\tMAE: 0.0152456413\n",
327 | "26:\tELBO: -1278.81643\tMAPE: 7707919102354.0146484375\tMAE: 0.0151943730\n",
328 | "27:\tELBO: -1278.84533\tMAPE: 7539073130139.2382812500\tMAE: 0.0151293445\n",
329 | "28:\tELBO: -1278.86449\tMAPE: 7224474698286.8066406250\tMAE: 0.0150343253\n",
330 | "29:\tELBO: -1278.87094\tMAPE: 7020631584527.4287109375\tMAE: 0.0149857593\n",
331 | "30:\tELBO: -1278.87187\tMAPE: 6892915789356.0253906250\tMAE: 0.0149543188\n",
332 | "31:\tELBO: -1278.87345\tMAPE: 6933931043809.7353515625\tMAE: 0.0149791947\n",
333 | "32:\tELBO: -1278.87372\tMAPE: 6953906355615.4541015625\tMAE: 0.0149931418\n",
334 | "33:\tELBO: -1278.87380\tMAPE: 6962555930031.3271484375\tMAE: 0.0149990226\n",
335 | "34:\tELBO: -1278.87391\tMAPE: 6972377957677.8408203125\tMAE: 0.0150041587\n",
336 | "35:\tELBO: -1278.87407\tMAPE: 6999747745278.0556640625\tMAE: 0.0150174678\n",
337 | "36:\tELBO: -1278.87433\tMAPE: 7013017865307.9316406250\tMAE: 0.0150206843\n",
338 | "37:\tELBO: -1278.87466\tMAPE: 7026939899197.2529296875\tMAE: 0.0150230092\n",
339 | "38:\tELBO: -1278.87483\tMAPE: 7035455598857.6240234375\tMAE: 0.0150361204\n",
340 | "39:\tELBO: -1278.87497\tMAPE: 7037492223955.2666015625\tMAE: 0.0150369856\n",
341 | "40:\tELBO: -1278.87500\tMAPE: 7030220487911.1103515625\tMAE: 0.0150336387\n",
342 | "41:\tELBO: -1278.87501\tMAPE: 7027909963403.5312500000\tMAE: 0.0150336282\n",
343 | "42:\tELBO: -1278.87502\tMAPE: 7025279937211.9804687500\tMAE: 0.0150345782\n",
344 | "43:\tELBO: -1278.87503\tMAPE: 7025079372694.2529296875\tMAE: 0.0150342980\n",
345 | "44:\tELBO: -1278.87503\tMAPE: 7026379800150.3271484375\tMAE: 0.0150343575\n",
346 | "0:\tELBO: -519.79241\tMAPE: 47094701512096.9765625000\tMAE: 0.0419916225\n",
347 | "1:\tELBO: -642.60579\tMAPE: 65209310769057.3671875000\tMAE: 0.0625903159\n",
348 | "2:\tELBO: -797.17138\tMAPE: 27188032643587.4648437500\tMAE: 0.0409386327\n",
349 | "3:\tELBO: -954.28472\tMAPE: 3589493681797.2075195312\tMAE: 0.0241015363\n",
350 | "4:\tELBO: -1066.91194\tMAPE: 6654418656405.2246093750\tMAE: 0.0315169131\n",
351 | "5:\tELBO: -1169.14704\tMAPE: 6929709507510.4189453125\tMAE: 0.0234854739\n",
352 | "6:\tELBO: -1195.81102\tMAPE: 10793785690642.3535156250\tMAE: 0.0211854408\n",
353 | "7:\tELBO: -1232.85035\tMAPE: 9968134877714.1660156250\tMAE: 0.0203104921\n",
354 | "8:\tELBO: -1248.11099\tMAPE: 5674621747493.0029296875\tMAE: 0.0169311901\n",
355 | "9:\tELBO: -1251.53278\tMAPE: 4131329407493.2729492188\tMAE: 0.0165374084\n",
356 | "10:\tELBO: -1255.01583\tMAPE: 3696809370498.4003906250\tMAE: 0.0162317975\n",
357 | "11:\tELBO: -1264.56125\tMAPE: 3744568768250.2202148438\tMAE: 0.0165683510\n",
358 | "12:\tELBO: -1271.62490\tMAPE: 4022376521002.9277343750\tMAE: 0.0162262307\n",
359 | "13:\tELBO: -1275.63263\tMAPE: 3389455891615.2514648438\tMAE: 0.0156273497\n",
360 | "14:\tELBO: -1276.47830\tMAPE: 4282307468719.4614257812\tMAE: 0.0154584982\n",
361 | "15:\tELBO: -1276.76596\tMAPE: 3738680122564.9448242188\tMAE: 0.0152270809\n",
362 | "16:\tELBO: -1276.79099\tMAPE: 3614780996295.5629882812\tMAE: 0.0151782369\n",
363 | "17:\tELBO: -1276.80484\tMAPE: 3647527261675.1333007812\tMAE: 0.0151561136\n",
364 | "18:\tELBO: -1276.86540\tMAPE: 3832416315064.9487304688\tMAE: 0.0150273220\n",
365 | "19:\tELBO: -1276.92356\tMAPE: 3756266476527.6220703125\tMAE: 0.0149830958\n",
366 | "20:\tELBO: -1276.93491\tMAPE: 3040474971349.3574218750\tMAE: 0.0147628656\n",
367 | "21:\tELBO: -1277.06843\tMAPE: 3133978412293.5195312500\tMAE: 0.0147638449\n",
368 | "22:\tELBO: -1277.10299\tMAPE: 3029906214031.9760742188\tMAE: 0.0147212688\n",
369 | "23:\tELBO: -1277.12230\tMAPE: 2839962353927.0424804688\tMAE: 0.0146185856\n",
370 | "24:\tELBO: -1277.12528\tMAPE: 2701675392612.4130859375\tMAE: 0.0145803636\n",
371 | "25:\tELBO: -1277.13332\tMAPE: 2609166948870.6406250000\tMAE: 0.0145150892\n",
372 | "26:\tELBO: -1277.13467\tMAPE: 2570988354208.7988281250\tMAE: 0.0144889899\n",
373 | "27:\tELBO: -1277.13593\tMAPE: 2556314874755.4677734375\tMAE: 0.0144777154\n",
374 | "28:\tELBO: -1277.14045\tMAPE: 2526059441418.5654296875\tMAE: 0.0144427428\n",
375 | "29:\tELBO: -1277.14629\tMAPE: 2431786204992.2534179688\tMAE: 0.0143633902\n",
376 | "30:\tELBO: -1277.15699\tMAPE: 2425708078729.8173828125\tMAE: 0.0143046709\n",
377 | "31:\tELBO: -1277.18262\tMAPE: 2398831526838.1914062500\tMAE: 0.0142649993\n",
378 | "32:\tELBO: -1277.18945\tMAPE: 2288610388353.8764648438\tMAE: 0.0142453791\n",
379 | "33:\tELBO: -1277.20083\tMAPE: 2364219245299.2358398438\tMAE: 0.0142814754\n",
380 | "34:\tELBO: -1277.20205\tMAPE: 2414128699130.0214843750\tMAE: 0.0143121317\n"
381 | ]
382 | },
383 | {
384 | "name": "stdout",
385 | "output_type": "stream",
386 | "text": [
387 | "35:\tELBO: -1277.20296\tMAPE: 2491869431191.7231445312\tMAE: 0.0143483481\n",
388 | "36:\tELBO: -1277.20305\tMAPE: 2500473875804.3549804688\tMAE: 0.0143436384\n",
389 | "37:\tELBO: -1277.20314\tMAPE: 2517329376811.9497070312\tMAE: 0.0143466934\n",
390 | "38:\tELBO: -1277.20316\tMAPE: 2524545864282.4096679688\tMAE: 0.0143495128\n",
391 | "39:\tELBO: -1277.20317\tMAPE: 2524239721699.1132812500\tMAE: 0.0143492165\n",
392 | "40:\tELBO: -1277.20321\tMAPE: 2526145043672.5214843750\tMAE: 0.0143530467\n",
393 | "41:\tELBO: -1277.20321\tMAPE: 2529316333127.9990234375\tMAE: 0.0143545050\n",
394 | "42:\tELBO: -1277.20322\tMAPE: 2531252419162.5708007812\tMAE: 0.0143580824\n",
395 | "43:\tELBO: -1277.20323\tMAPE: 2533650810371.6552734375\tMAE: 0.0143604720\n",
396 | "0:\tELBO: -521.08905\tMAPE: 68019517970247.7968750000\tMAE: 0.0537527602\n",
397 | "1:\tELBO: -665.51654\tMAPE: 90936912331365.0156250000\tMAE: 0.0684862582\n",
398 | "2:\tELBO: -842.00021\tMAPE: 3459708601475.0395507812\tMAE: 0.0273101718\n",
399 | "3:\tELBO: -1008.64442\tMAPE: 31207104876703.4257812500\tMAE: 0.0258855175\n",
400 | "4:\tELBO: -1176.74560\tMAPE: 4191272739107.7285156250\tMAE: 0.0243821626\n",
401 | "5:\tELBO: -1215.21810\tMAPE: 7090470852771.8310546875\tMAE: 0.0201584682\n",
402 | "6:\tELBO: -1227.80443\tMAPE: 9405309749135.5566406250\tMAE: 0.0183328692\n",
403 | "7:\tELBO: -1236.67358\tMAPE: 8547768295646.7968750000\tMAE: 0.0178920828\n",
404 | "8:\tELBO: -1243.40004\tMAPE: 7091426344815.4150390625\tMAE: 0.0189101261\n",
405 | "9:\tELBO: -1250.42017\tMAPE: 8279727934042.0380859375\tMAE: 0.0181506682\n",
406 | "10:\tELBO: -1264.22089\tMAPE: 10276628809293.3437500000\tMAE: 0.0170432312\n",
407 | "11:\tELBO: -1271.91345\tMAPE: 11299430505479.5175781250\tMAE: 0.0165148700\n",
408 | "12:\tELBO: -1274.75487\tMAPE: 12364209640759.1835937500\tMAE: 0.0164789499\n",
409 | "13:\tELBO: -1275.53123\tMAPE: 12639778014502.7128906250\tMAE: 0.0164599607\n",
410 | "14:\tELBO: -1276.46001\tMAPE: 12920335111939.4824218750\tMAE: 0.0164411719\n",
411 | "15:\tELBO: -1277.09391\tMAPE: 12831662182964.8183593750\tMAE: 0.0166061050\n",
412 | "16:\tELBO: -1277.71772\tMAPE: 12608087502809.1542968750\tMAE: 0.0166674013\n",
413 | "17:\tELBO: -1277.82825\tMAPE: 12596185463566.3574218750\tMAE: 0.0166148066\n",
414 | "18:\tELBO: -1277.90408\tMAPE: 12644198659669.3105468750\tMAE: 0.0166189210\n",
415 | "19:\tELBO: -1277.95040\tMAPE: 12734042337286.3125000000\tMAE: 0.0166448879\n",
416 | "20:\tELBO: -1277.96971\tMAPE: 12851836331166.3691406250\tMAE: 0.0165852806\n",
417 | "21:\tELBO: -1278.00418\tMAPE: 12880132816083.6054687500\tMAE: 0.0166216619\n",
418 | "22:\tELBO: -1278.01380\tMAPE: 12937847915983.5683593750\tMAE: 0.0166392628\n",
419 | "23:\tELBO: -1278.02038\tMAPE: 12925317513171.9570312500\tMAE: 0.0166389285\n",
420 | "24:\tELBO: -1278.02973\tMAPE: 12843399846781.7792968750\tMAE: 0.0166090923\n",
421 | "25:\tELBO: -1278.04166\tMAPE: 12715281529270.6757812500\tMAE: 0.0165750329\n",
422 | "26:\tELBO: -1278.06323\tMAPE: 12318200112652.2011718750\tMAE: 0.0164612932\n",
423 | "27:\tELBO: -1278.08433\tMAPE: 11635272367836.6191406250\tMAE: 0.0162332651\n",
424 | "28:\tELBO: -1278.10184\tMAPE: 10892129773702.4765625000\tMAE: 0.0159915966\n",
425 | "29:\tELBO: -1278.10815\tMAPE: 10792966925400.6933593750\tMAE: 0.0159551961\n",
426 | "30:\tELBO: -1278.11196\tMAPE: 10775496335747.7226562500\tMAE: 0.0159158619\n",
427 | "31:\tELBO: -1278.11402\tMAPE: 10790790707871.8730468750\tMAE: 0.0158953731\n",
428 | "32:\tELBO: -1278.11654\tMAPE: 10798620458753.7753906250\tMAE: 0.0158531027\n",
429 | "33:\tELBO: -1278.11692\tMAPE: 10789404317704.1308593750\tMAE: 0.0158484178\n",
430 | "34:\tELBO: -1278.11733\tMAPE: 10778523172531.1191406250\tMAE: 0.0158385837\n",
431 | "35:\tELBO: -1278.11754\tMAPE: 10763967846259.1503906250\tMAE: 0.0158360727\n",
432 | "36:\tELBO: -1278.11766\tMAPE: 10747661469257.8261718750\tMAE: 0.0158330122\n",
433 | "37:\tELBO: -1278.11776\tMAPE: 10724779154076.9824218750\tMAE: 0.0158295045\n",
434 | "38:\tELBO: -1278.11788\tMAPE: 10705173906219.7070312500\tMAE: 0.0158241812\n",
435 | "39:\tELBO: -1278.11805\tMAPE: 10665500709689.7675781250\tMAE: 0.0158112115\n",
436 | "40:\tELBO: -1278.11812\tMAPE: 10634520998128.3886718750\tMAE: 0.0157959329\n",
437 | "41:\tELBO: -1278.11820\tMAPE: 10652741048641.3632812500\tMAE: 0.0158007227\n",
438 | "42:\tELBO: -1278.11822\tMAPE: 10657740484224.0839843750\tMAE: 0.0158010413\n",
439 | "43:\tELBO: -1278.11823\tMAPE: 10660779733431.6894531250\tMAE: 0.0158006113\n",
440 | "44:\tELBO: -1278.11825\tMAPE: 10658784850593.0800781250\tMAE: 0.0158005167\n",
441 | "45:\tELBO: -1278.11827\tMAPE: 10661207203585.9199218750\tMAE: 0.0158032696\n",
442 | "46:\tELBO: -1278.11828\tMAPE: 10659657267555.9550781250\tMAE: 0.0158044758\n",
443 | "47:\tELBO: -1278.11829\tMAPE: 10657757162064.8984375000\tMAE: 0.0158055240\n",
444 | "0:\tELBO: -522.15367\tMAPE: 82235191592533.6406250000\tMAE: 0.0639131025\n",
445 | "1:\tELBO: -720.13134\tMAPE: 87703268280454.0468750000\tMAE: 0.0638746602\n",
446 | "2:\tELBO: -872.21950\tMAPE: 18586644932695.9804687500\tMAE: 0.0316589185\n",
447 | "3:\tELBO: -979.01129\tMAPE: 28410843288486.3789062500\tMAE: 0.0219384607\n",
448 | "4:\tELBO: -1198.05281\tMAPE: 16127058737849.8007812500\tMAE: 0.0215174370\n",
449 | "5:\tELBO: -1218.19624\tMAPE: 17519488420925.2675781250\tMAE: 0.0236774149\n",
450 | "6:\tELBO: -1227.20854\tMAPE: 18781783377051.6484375000\tMAE: 0.0219471771\n",
451 | "7:\tELBO: -1231.52951\tMAPE: 19889784649207.1015625000\tMAE: 0.0206980995\n",
452 | "8:\tELBO: -1236.80313\tMAPE: 21109932964091.4765625000\tMAE: 0.0200537201\n",
453 | "9:\tELBO: -1241.41854\tMAPE: 33905295548058.7500000000\tMAE: 0.0225442010\n",
454 | "10:\tELBO: -1269.85781\tMAPE: 25795146136540.0156250000\tMAE: 0.0181581170\n",
455 | "11:\tELBO: -1275.19789\tMAPE: 25755449885608.7031250000\tMAE: 0.0174531844\n",
456 | "12:\tELBO: -1278.25337\tMAPE: 25698624124701.3710937500\tMAE: 0.0174796440\n",
457 | "13:\tELBO: -1279.26456\tMAPE: 27939553987829.5742187500\tMAE: 0.0177005515\n",
458 | "14:\tELBO: -1279.92458\tMAPE: 27053834370817.0976562500\tMAE: 0.0183809493\n",
459 | "15:\tELBO: -1281.05066\tMAPE: 26612924212027.3281250000\tMAE: 0.0177251352\n",
460 | "16:\tELBO: -1281.26608\tMAPE: 26956740950505.6914062500\tMAE: 0.0176698139\n",
461 | "17:\tELBO: -1281.29782\tMAPE: 27141577473092.7031250000\tMAE: 0.0179079694\n",
462 | "18:\tELBO: -1281.48008\tMAPE: 27301243536679.5742187500\tMAE: 0.0177173199\n",
463 | "19:\tELBO: -1281.52672\tMAPE: 27238878998116.9140625000\tMAE: 0.0176826128\n",
464 | "20:\tELBO: -1281.57470\tMAPE: 27115251198371.3281250000\tMAE: 0.0176872310\n",
465 | "21:\tELBO: -1281.59759\tMAPE: 27104010190070.4335937500\tMAE: 0.0176549666\n",
466 | "22:\tELBO: -1281.60637\tMAPE: 26966284968426.0507812500\tMAE: 0.0177663000\n",
467 | "23:\tELBO: -1281.61957\tMAPE: 26998671179618.4218750000\tMAE: 0.0176809640\n",
468 | "24:\tELBO: -1281.62431\tMAPE: 26988619986260.7343750000\tMAE: 0.0176470423\n",
469 | "25:\tELBO: -1281.63017\tMAPE: 26923092082453.2148437500\tMAE: 0.0176034375\n",
470 | "26:\tELBO: -1281.63898\tMAPE: 26722676485388.6718750000\tMAE: 0.0175519992\n",
471 | "27:\tELBO: -1281.65547\tMAPE: 26212587042302.2109375000\tMAE: 0.0173928576\n",
472 | "28:\tELBO: -1281.65880\tMAPE: 25976634600511.3085937500\tMAE: 0.0173225587\n",
473 | "29:\tELBO: -1281.66094\tMAPE: 25754778675314.3515625000\tMAE: 0.0172685540\n",
474 | "30:\tELBO: -1281.66164\tMAPE: 25724765054535.5507812500\tMAE: 0.0172607472\n",
475 | "31:\tELBO: -1281.66678\tMAPE: 25470593648738.1093750000\tMAE: 0.0171813441\n",
476 | "32:\tELBO: -1281.66719\tMAPE: 25468228314825.3632812500\tMAE: 0.0171834748\n",
477 | "33:\tELBO: -1281.66795\tMAPE: 25460164558859.2109375000\tMAE: 0.0171755328\n",
478 | "34:\tELBO: -1281.66885\tMAPE: 25440103095175.9960937500\tMAE: 0.0171600085\n",
479 | "35:\tELBO: -1281.66986\tMAPE: 25391634283840.1093750000\tMAE: 0.0171391825\n",
480 | "36:\tELBO: -1281.67101\tMAPE: 25300979475623.0351562500\tMAE: 0.0171173608\n",
481 | "37:\tELBO: -1281.67134\tMAPE: 25239000947643.3398437500\tMAE: 0.0170940922\n",
482 | "38:\tELBO: -1281.67209\tMAPE: 25136598562973.3359375000\tMAE: 0.0170810487\n",
483 | "39:\tELBO: -1281.67228\tMAPE: 25112200997096.2343750000\tMAE: 0.0170877895\n",
484 | "40:\tELBO: -1281.67234\tMAPE: 25111397172495.9726562500\tMAE: 0.0170895016\n",
485 | "41:\tELBO: -1281.67237\tMAPE: 25115574770986.0273437500\tMAE: 0.0170914068\n",
486 | "42:\tELBO: -1281.67245\tMAPE: 25120737710573.9453125000\tMAE: 0.0170996461\n",
487 | "43:\tELBO: -1281.67249\tMAPE: 25089212563686.7656250000\tMAE: 0.0170963778\n",
488 | "44:\tELBO: -1281.67254\tMAPE: 25094720451288.3906250000\tMAE: 0.0170954273\n",
489 | "45:\tELBO: -1281.67254\tMAPE: 25091943563790.7343750000\tMAE: 0.0170951526\n",
490 | "46:\tELBO: -1281.67255\tMAPE: 25085533633260.3164062500\tMAE: 0.0170945075\n",
491 | "47:\tELBO: -1281.67255\tMAPE: 25083622582924.5117187500\tMAE: 0.0170949973\n",
492 | "48:\tELBO: -1281.67255\tMAPE: 25083547129635.8750000000\tMAE: 0.0170951321\n",
493 | "0:\tELBO: -522.51681\tMAPE: 69078328074036.9218750000\tMAE: 0.0678500818\n",
494 | "1:\tELBO: -817.05799\tMAPE: 18089489038067.8593750000\tMAE: 0.0313402229\n",
495 | "2:\tELBO: -909.07750\tMAPE: 41165306429346.3828125000\tMAE: 0.0218203564\n",
496 | "3:\tELBO: -1150.35684\tMAPE: 22284681188393.8515625000\tMAE: 0.0213170841\n",
497 | "4:\tELBO: -1228.55178\tMAPE: 29702625080826.1835937500\tMAE: 0.0171755223\n",
498 | "5:\tELBO: -1250.93732\tMAPE: 30444243894869.4726562500\tMAE: 0.0184554905\n",
499 | "6:\tELBO: -1269.42641\tMAPE: 28361893435156.5859375000\tMAE: 0.0185696993\n",
500 | "7:\tELBO: -1274.81913\tMAPE: 28935329762654.3867187500\tMAE: 0.0181609909\n",
501 | "8:\tELBO: -1283.62115\tMAPE: 33175985693652.7031250000\tMAE: 0.0181439096\n"
502 | ]
503 | },
504 | {
505 | "name": "stdout",
506 | "output_type": "stream",
507 | "text": [
508 | "9:\tELBO: -1284.04337\tMAPE: 34579609470771.3750000000\tMAE: 0.0166230177\n",
509 | "10:\tELBO: -1285.04661\tMAPE: 33574613956859.9882812500\tMAE: 0.0171819991\n",
510 | "11:\tELBO: -1285.27593\tMAPE: 33148028469392.8476562500\tMAE: 0.0171401894\n",
511 | "12:\tELBO: -1285.72107\tMAPE: 32195159015656.6640625000\tMAE: 0.0168901387\n",
512 | "13:\tELBO: -1285.98997\tMAPE: 32172562807961.9375000000\tMAE: 0.0169141533\n",
513 | "14:\tELBO: -1286.29218\tMAPE: 31796270688407.5351562500\tMAE: 0.0169242702\n",
514 | "15:\tELBO: -1286.48071\tMAPE: 32253480303128.7851562500\tMAE: 0.0169294422\n",
515 | "16:\tELBO: -1286.53306\tMAPE: 32534170601127.3476562500\tMAE: 0.0168796419\n",
516 | "17:\tELBO: -1286.54933\tMAPE: 32669192992922.8398437500\tMAE: 0.0169687413\n",
517 | "18:\tELBO: -1286.57732\tMAPE: 32531475290535.0781250000\tMAE: 0.0169305335\n",
518 | "19:\tELBO: -1286.59463\tMAPE: 32375366356853.6992187500\tMAE: 0.0168943717\n",
519 | "20:\tELBO: -1286.60650\tMAPE: 32242135353661.4023437500\tMAE: 0.0168647321\n",
520 | "21:\tELBO: -1286.60943\tMAPE: 32246556005226.4101562500\tMAE: 0.0168397106\n",
521 | "22:\tELBO: -1286.61941\tMAPE: 32134283897761.3164062500\tMAE: 0.0168163106\n",
522 | "23:\tELBO: -1286.62701\tMAPE: 32027850229477.3906250000\tMAE: 0.0167808170\n",
523 | "24:\tELBO: -1286.63348\tMAPE: 31883240891545.6992187500\tMAE: 0.0167304074\n",
524 | "25:\tELBO: -1286.64195\tMAPE: 31432348367427.1484375000\tMAE: 0.0165708321\n",
525 | "26:\tELBO: -1286.64885\tMAPE: 31200686424798.1835937500\tMAE: 0.0164900658\n",
526 | "27:\tELBO: -1286.65087\tMAPE: 30961638977821.1835937500\tMAE: 0.0164140162\n",
527 | "28:\tELBO: -1286.65176\tMAPE: 30874116376024.5742187500\tMAE: 0.0163943085\n",
528 | "29:\tELBO: -1286.65389\tMAPE: 30711586513603.4609375000\tMAE: 0.0163457104\n",
529 | "30:\tELBO: -1286.65458\tMAPE: 30574664541874.1914062500\tMAE: 0.0163514769\n",
530 | "31:\tELBO: -1286.65619\tMAPE: 30660203371464.7851562500\tMAE: 0.0163563914\n",
531 | "32:\tELBO: -1286.65697\tMAPE: 30715227145000.8632812500\tMAE: 0.0163651981\n",
532 | "33:\tELBO: -1286.65755\tMAPE: 30748782441186.6523437500\tMAE: 0.0163691488\n",
533 | "34:\tELBO: -1286.65780\tMAPE: 30661496353765.8867187500\tMAE: 0.0163454602\n",
534 | "35:\tELBO: -1286.65830\tMAPE: 30694997422071.6601562500\tMAE: 0.0163444472\n",
535 | "36:\tELBO: -1286.65841\tMAPE: 30681277161309.7109375000\tMAE: 0.0163360131\n",
536 | "37:\tELBO: -1286.65850\tMAPE: 30665164725497.8710937500\tMAE: 0.0163266644\n",
537 | "38:\tELBO: -1286.65872\tMAPE: 30657403061803.4531250000\tMAE: 0.0163171194\n",
538 | "39:\tELBO: -1286.65898\tMAPE: 30686891109585.6289062500\tMAE: 0.0163010298\n",
539 | "40:\tELBO: -1286.65945\tMAPE: 30696249952440.0468750000\tMAE: 0.0162997012\n",
540 | "41:\tELBO: -1286.65958\tMAPE: 30726702821307.9531250000\tMAE: 0.0163112760\n",
541 | "42:\tELBO: -1286.65968\tMAPE: 30780973443715.9414062500\tMAE: 0.0163289970\n",
542 | "43:\tELBO: -1286.65973\tMAPE: 30812321363306.3867187500\tMAE: 0.0163365478\n",
543 | "44:\tELBO: -1286.65979\tMAPE: 30829760425579.1093750000\tMAE: 0.0163423517\n",
544 | "45:\tELBO: -1286.65990\tMAPE: 30869644431424.8359375000\tMAE: 0.0163528136\n",
545 | "46:\tELBO: -1286.65990\tMAPE: 30865364808791.4101562500\tMAE: 0.0163527014\n",
546 | "47:\tELBO: -1286.65990\tMAPE: 30862256531520.5898437500\tMAE: 0.0163516420\n",
547 | "0:\tELBO: -522.19575\tMAPE: 52916303506929.7500000000\tMAE: 0.0642678709\n",
548 | "1:\tELBO: -800.71632\tMAPE: 76230939384819.5156250000\tMAE: 0.0296725762\n",
549 | "2:\tELBO: -925.59451\tMAPE: 35659161434789.2578125000\tMAE: 0.0274146862\n",
550 | "3:\tELBO: -1140.61742\tMAPE: 26283418479485.2851562500\tMAE: 0.0233124361\n",
551 | "4:\tELBO: -1186.13060\tMAPE: 26160010070625.8242187500\tMAE: 0.0226272759\n",
552 | "5:\tELBO: -1221.98815\tMAPE: 30967492842035.0390625000\tMAE: 0.0210596461\n",
553 | "6:\tELBO: -1239.35990\tMAPE: 31846287040285.5273437500\tMAE: 0.0191973200\n",
554 | "7:\tELBO: -1244.98609\tMAPE: 31353444892679.9101562500\tMAE: 0.0188618518\n",
555 | "8:\tELBO: -1253.64280\tMAPE: 30468929171056.0976562500\tMAE: 0.0178812455\n",
556 | "9:\tELBO: -1269.79773\tMAPE: 30011173514555.6718750000\tMAE: 0.0170534604\n",
557 | "10:\tELBO: -1283.59351\tMAPE: 33091560684323.1210937500\tMAE: 0.0193294718\n",
558 | "11:\tELBO: -1286.43802\tMAPE: 34475455320082.9765625000\tMAE: 0.0204770568\n",
559 | "12:\tELBO: -1287.76528\tMAPE: 34138372471640.3398437500\tMAE: 0.0204254346\n",
560 | "13:\tELBO: -1289.22813\tMAPE: 33442020312216.8906250000\tMAE: 0.0199827274\n",
561 | "14:\tELBO: -1289.97658\tMAPE: 33052283073932.3125000000\tMAE: 0.0197176757\n",
562 | "15:\tELBO: -1290.40340\tMAPE: 33227801770983.4375000000\tMAE: 0.0198406406\n",
563 | "16:\tELBO: -1290.52901\tMAPE: 33167441229567.1289062500\tMAE: 0.0198448836\n",
564 | "17:\tELBO: -1290.66717\tMAPE: 32918253987120.8867187500\tMAE: 0.0197861913\n",
565 | "18:\tELBO: -1290.69447\tMAPE: 32894631606781.6914062500\tMAE: 0.0197366797\n",
566 | "19:\tELBO: -1290.73168\tMAPE: 32819471490179.8398437500\tMAE: 0.0197596083\n",
567 | "20:\tELBO: -1290.75021\tMAPE: 32645375876971.1093750000\tMAE: 0.0197367539\n",
568 | "21:\tELBO: -1290.76733\tMAPE: 32640360617410.6406250000\tMAE: 0.0197671108\n",
569 | "22:\tELBO: -1290.77288\tMAPE: 32594640705344.3281250000\tMAE: 0.0197661754\n",
570 | "23:\tELBO: -1290.77805\tMAPE: 32487237349728.8242187500\tMAE: 0.0197599903\n",
571 | "24:\tELBO: -1290.78194\tMAPE: 32398791330484.6757812500\tMAE: 0.0197530030\n",
572 | "25:\tELBO: -1290.78546\tMAPE: 32264673753876.9218750000\tMAE: 0.0197465922\n",
573 | "26:\tELBO: -1290.79115\tMAPE: 32011506714985.0234375000\tMAE: 0.0197302808\n",
574 | "27:\tELBO: -1290.79613\tMAPE: 31388569448274.4140625000\tMAE: 0.0196035804\n",
575 | "28:\tELBO: -1290.80243\tMAPE: 31097496509166.7617187500\tMAE: 0.0195742320\n",
576 | "29:\tELBO: -1290.80358\tMAPE: 31181195506867.3906250000\tMAE: 0.0195806747\n",
577 | "30:\tELBO: -1290.80610\tMAPE: 31330034715442.4648437500\tMAE: 0.0195769581\n",
578 | "31:\tELBO: -1290.80892\tMAPE: 31432572227553.7500000000\tMAE: 0.0195665715\n",
579 | "32:\tELBO: -1290.81583\tMAPE: 31632309058478.3593750000\tMAE: 0.0195722995\n",
580 | "33:\tELBO: -1290.81910\tMAPE: 31811774938787.7500000000\tMAE: 0.0195765518\n",
581 | "34:\tELBO: -1290.82296\tMAPE: 31790127127438.3906250000\tMAE: 0.0195878731\n",
582 | "35:\tELBO: -1290.82491\tMAPE: 31722295771129.2539062500\tMAE: 0.0195993294\n",
583 | "36:\tELBO: -1290.82517\tMAPE: 31713405121448.2656250000\tMAE: 0.0195960443\n",
584 | "37:\tELBO: -1290.82544\tMAPE: 31702423010272.0117187500\tMAE: 0.0195949671\n",
585 | "38:\tELBO: -1290.82587\tMAPE: 31659820198845.7968750000\tMAE: 0.0195709962\n",
586 | "39:\tELBO: -1290.82597\tMAPE: 31676239332979.1250000000\tMAE: 0.0195641376\n",
587 | "40:\tELBO: -1290.82598\tMAPE: 31682792680188.2500000000\tMAE: 0.0195618603\n",
588 | "41:\tELBO: -1290.82599\tMAPE: 31687501236252.8867187500\tMAE: 0.0195621234\n",
589 | "42:\tELBO: -1290.82599\tMAPE: 31689843454449.4492187500\tMAE: 0.0195624742\n",
590 | "0:\tELBO: -521.35444\tMAPE: 35801917096802.5390625000\tMAE: 0.0571997696\n",
591 | "1:\tELBO: -715.64735\tMAPE: 60777878988888.5546875000\tMAE: 0.0307179066\n",
592 | "2:\tELBO: -829.39478\tMAPE: 51554651490749.2421875000\tMAE: 0.0357499043\n",
593 | "3:\tELBO: -988.11448\tMAPE: 35863386516214.4062500000\tMAE: 0.0334929253\n",
594 | "4:\tELBO: -1201.43288\tMAPE: 30976222394326.7734375000\tMAE: 0.0245240604\n",
595 | "5:\tELBO: -1222.37465\tMAPE: 30483047820516.2539062500\tMAE: 0.0244531090\n",
596 | "6:\tELBO: -1234.90345\tMAPE: 32516287513520.2539062500\tMAE: 0.0263577711\n",
597 | "7:\tELBO: -1240.98195\tMAPE: 31941418432810.1601562500\tMAE: 0.0250581895\n",
598 | "8:\tELBO: -1246.77259\tMAPE: 30852217002416.6484375000\tMAE: 0.0244405822\n",
599 | "9:\tELBO: -1263.68048\tMAPE: 26027733455792.0117187500\tMAE: 0.0219368620\n",
600 | "10:\tELBO: -1274.09728\tMAPE: 25239776402125.0039062500\tMAE: 0.0217237020\n",
601 | "11:\tELBO: -1286.44312\tMAPE: 27467162305070.1835937500\tMAE: 0.0239303031\n",
602 | "12:\tELBO: -1288.96289\tMAPE: 29836142856237.5625000000\tMAE: 0.0254476524\n",
603 | "13:\tELBO: -1289.80663\tMAPE: 30238414135542.4023437500\tMAE: 0.0258952701\n",
604 | "14:\tELBO: -1290.76937\tMAPE: 30233851114979.1250000000\tMAE: 0.0259971088\n",
605 | "15:\tELBO: -1291.12443\tMAPE: 29775423431162.8515625000\tMAE: 0.0256080684\n",
606 | "16:\tELBO: -1291.79750\tMAPE: 28906762581782.6015625000\tMAE: 0.0251480729\n",
607 | "17:\tELBO: -1291.94640\tMAPE: 28740762897558.4843750000\tMAE: 0.0251351991\n",
608 | "18:\tELBO: -1292.05143\tMAPE: 28564298672965.9140625000\tMAE: 0.0250854416\n",
609 | "19:\tELBO: -1292.10701\tMAPE: 28472106721586.0898437500\tMAE: 0.0252241695\n",
610 | "20:\tELBO: -1292.14510\tMAPE: 28440601438754.5625000000\tMAE: 0.0252218768\n",
611 | "21:\tELBO: -1292.17198\tMAPE: 28355150752966.1601562500\tMAE: 0.0252316372\n",
612 | "22:\tELBO: -1292.18151\tMAPE: 28271582853568.0976562500\tMAE: 0.0252277743\n",
613 | "23:\tELBO: -1292.18404\tMAPE: 28221325664309.5468750000\tMAE: 0.0251959886\n",
614 | "24:\tELBO: -1292.19276\tMAPE: 28092234015845.3632812500\tMAE: 0.0251747251\n",
615 | "25:\tELBO: -1292.19990\tMAPE: 27990464278095.7343750000\tMAE: 0.0251463734\n",
616 | "26:\tELBO: -1292.20967\tMAPE: 27891166906012.6757812500\tMAE: 0.0251215913\n",
617 | "27:\tELBO: -1292.22446\tMAPE: 27788158062065.6015625000\tMAE: 0.0250971555\n",
618 | "28:\tELBO: -1292.23837\tMAPE: 27708383019611.9335937500\tMAE: 0.0250906095\n",
619 | "29:\tELBO: -1292.25449\tMAPE: 27622297494522.4882812500\tMAE: 0.0251349816\n",
620 | "30:\tELBO: -1292.25734\tMAPE: 27664963250932.3984375000\tMAE: 0.0251909087\n",
621 | "31:\tELBO: -1292.26112\tMAPE: 27721126645249.6093750000\tMAE: 0.0251936833\n",
622 | "32:\tELBO: -1292.26535\tMAPE: 27848825869927.2500000000\tMAE: 0.0251907188\n"
623 | ]
624 | },
625 | {
626 | "name": "stdout",
627 | "output_type": "stream",
628 | "text": [
629 | "33:\tELBO: -1292.26736\tMAPE: 27911547138378.1875000000\tMAE: 0.0251778071\n",
630 | "34:\tELBO: -1292.26752\tMAPE: 27916307742457.7148437500\tMAE: 0.0251761525\n",
631 | "35:\tELBO: -1292.26784\tMAPE: 27919651561613.9375000000\tMAE: 0.0251660100\n",
632 | "36:\tELBO: -1292.26789\tMAPE: 27905783991132.9218750000\tMAE: 0.0251625553\n",
633 | "37:\tELBO: -1292.26793\tMAPE: 27886630284192.8242187500\tMAE: 0.0251597554\n",
634 | "38:\tELBO: -1292.26800\tMAPE: 27851570831452.5625000000\tMAE: 0.0251541377\n",
635 | "39:\tELBO: -1292.26812\tMAPE: 27799530143396.3007812500\tMAE: 0.0251475663\n",
636 | "40:\tELBO: -1292.26823\tMAPE: 27737777633593.5000000000\tMAE: 0.0251396172\n",
637 | "41:\tELBO: -1292.26839\tMAPE: 27681596721409.9218750000\tMAE: 0.0251332861\n",
638 | "42:\tELBO: -1292.26843\tMAPE: 27689487680890.6718750000\tMAE: 0.0251356992\n",
639 | "43:\tELBO: -1292.26843\tMAPE: 27694614935486.7382812500\tMAE: 0.0251352941\n",
640 | "44:\tELBO: -1292.26843\tMAPE: 27695291540577.0351562500\tMAE: 0.0251359912\n",
641 | "0:\tELBO: -520.12121\tMAPE: 19085964693587.7656250000\tMAE: 0.0464643943\n",
642 | "1:\tELBO: -665.83758\tMAPE: 63045907583452.5078125000\tMAE: 0.0377210679\n",
643 | "2:\tELBO: -799.60385\tMAPE: 63041132859630.1562500000\tMAE: 0.0478662777\n",
644 | "3:\tELBO: -947.16478\tMAPE: 49415743199987.6093750000\tMAE: 0.0480970568\n",
645 | "4:\tELBO: -1224.42348\tMAPE: 30272679939671.7734375000\tMAE: 0.0274745778\n",
646 | "5:\tELBO: -1236.27698\tMAPE: 29345009171478.5781250000\tMAE: 0.0279951407\n",
647 | "6:\tELBO: -1242.64574\tMAPE: 29834794063691.5039062500\tMAE: 0.0281846947\n",
648 | "7:\tELBO: -1246.22364\tMAPE: 28813335014028.5273437500\tMAE: 0.0279354804\n",
649 | "8:\tELBO: -1258.23751\tMAPE: 24094018340668.1757812500\tMAE: 0.0248706105\n",
650 | "9:\tELBO: -1266.00660\tMAPE: 20518161992054.3867187500\tMAE: 0.0230859344\n",
651 | "10:\tELBO: -1274.99134\tMAPE: 19764426642352.0273437500\tMAE: 0.0228396701\n",
652 | "11:\tELBO: -1286.18341\tMAPE: 20236180709791.2968750000\tMAE: 0.0239911448\n",
653 | "12:\tELBO: -1287.91993\tMAPE: 23415055835568.8085937500\tMAE: 0.0262219558\n",
654 | "13:\tELBO: -1289.47993\tMAPE: 22960037468117.4570312500\tMAE: 0.0259184993\n",
655 | "14:\tELBO: -1289.93591\tMAPE: 22706931938170.9140625000\tMAE: 0.0259055732\n",
656 | "15:\tELBO: -1290.78481\tMAPE: 22142873579127.2890625000\tMAE: 0.0257428495\n",
657 | "16:\tELBO: -1290.95953\tMAPE: 21834400320442.6406250000\tMAE: 0.0253635754\n",
658 | "17:\tELBO: -1291.42661\tMAPE: 21547297115479.1835937500\tMAE: 0.0253845338\n",
659 | "18:\tELBO: -1291.54401\tMAPE: 21404747194412.9375000000\tMAE: 0.0254279655\n",
660 | "19:\tELBO: -1291.61491\tMAPE: 21148752710847.1953125000\tMAE: 0.0253954081\n",
661 | "20:\tELBO: -1291.62401\tMAPE: 21201071288170.5078125000\tMAE: 0.0254324021\n",
662 | "21:\tELBO: -1291.66410\tMAPE: 20913343775791.9648437500\tMAE: 0.0253768308\n",
663 | "22:\tELBO: -1291.68854\tMAPE: 20802906367152.6718750000\tMAE: 0.0254136723\n",
664 | "23:\tELBO: -1291.70225\tMAPE: 20724522881699.0585937500\tMAE: 0.0254106758\n",
665 | "24:\tELBO: -1291.71356\tMAPE: 20600829292604.7539062500\tMAE: 0.0254094238\n",
666 | "25:\tELBO: -1291.72800\tMAPE: 20589589164613.6250000000\tMAE: 0.0253993871\n",
667 | "26:\tELBO: -1291.74384\tMAPE: 20555784997867.4531250000\tMAE: 0.0253696391\n",
668 | "27:\tELBO: -1291.76168\tMAPE: 20505356774216.2734375000\tMAE: 0.0253199341\n",
669 | "28:\tELBO: -1291.79195\tMAPE: 20402786912127.2148437500\tMAE: 0.0252806837\n",
670 | "29:\tELBO: -1291.80976\tMAPE: 20633077944323.2851562500\tMAE: 0.0253801531\n",
671 | "30:\tELBO: -1291.81923\tMAPE: 20553624559267.4414062500\tMAE: 0.0254227678\n",
672 | "31:\tELBO: -1291.82800\tMAPE: 20453027923451.8945312500\tMAE: 0.0253647713\n",
673 | "32:\tELBO: -1291.82885\tMAPE: 20419613464466.0039062500\tMAE: 0.0253471065\n",
674 | "33:\tELBO: -1291.82897\tMAPE: 20418544846720.1835937500\tMAE: 0.0253467103\n",
675 | "34:\tELBO: -1291.82899\tMAPE: 20418112853303.3320312500\tMAE: 0.0253468142\n",
676 | "35:\tELBO: -1291.82905\tMAPE: 20400968334069.3789062500\tMAE: 0.0253404985\n",
677 | "36:\tELBO: -1291.82909\tMAPE: 20376327509056.4843750000\tMAE: 0.0253342911\n",
678 | "37:\tELBO: -1291.82911\tMAPE: 20348522004733.7851562500\tMAE: 0.0253275435\n",
679 | "38:\tELBO: -1291.82915\tMAPE: 20324334954787.7460937500\tMAE: 0.0253245953\n",
680 | "39:\tELBO: -1291.82919\tMAPE: 20291759115919.2031250000\tMAE: 0.0253230582\n",
681 | "40:\tELBO: -1291.82921\tMAPE: 20275319847090.7968750000\tMAE: 0.0253242757\n",
682 | "41:\tELBO: -1291.82921\tMAPE: 20272793632065.7734375000\tMAE: 0.0253250873\n",
683 | "42:\tELBO: -1291.82921\tMAPE: 20275423124907.7460937500\tMAE: 0.0253251100\n",
684 | "0:\tELBO: -518.70083\tMAPE: 12640138266988.9550781250\tMAE: 0.0329245867\n",
685 | "1:\tELBO: -644.91534\tMAPE: 72378201152983.8437500000\tMAE: 0.0431798797\n",
686 | "2:\tELBO: -773.59294\tMAPE: 71485533819127.7187500000\tMAE: 0.0522142577\n",
687 | "3:\tELBO: -925.35972\tMAPE: 51438873145054.7187500000\tMAE: 0.0505348459\n",
688 | "4:\tELBO: -1209.72388\tMAPE: 25630998696704.3867187500\tMAE: 0.0232826876\n",
689 | "5:\tELBO: -1229.09071\tMAPE: 24120093760417.4882812500\tMAE: 0.0249540982\n",
690 | "6:\tELBO: -1250.20315\tMAPE: 23736091971770.0390625000\tMAE: 0.0235600580\n",
691 | "7:\tELBO: -1254.91801\tMAPE: 22031896792355.8281250000\tMAE: 0.0231435012\n",
692 | "8:\tELBO: -1258.95030\tMAPE: 21128385446362.6132812500\tMAE: 0.0218452010\n",
693 | "9:\tELBO: -1267.20974\tMAPE: 17619185326120.9648437500\tMAE: 0.0200846176\n",
694 | "10:\tELBO: -1275.87349\tMAPE: 14009581526665.8769531250\tMAE: 0.0188532632\n",
695 | "11:\tELBO: -1284.22624\tMAPE: 13623340680891.2500000000\tMAE: 0.0195441853\n",
696 | "12:\tELBO: -1287.68882\tMAPE: 13962111245522.2363281250\tMAE: 0.0204568208\n",
697 | "13:\tELBO: -1288.05225\tMAPE: 13779950612443.1132812500\tMAE: 0.0202815434\n",
698 | "14:\tELBO: -1288.62632\tMAPE: 14365161341801.4707031250\tMAE: 0.0207865462\n",
699 | "15:\tELBO: -1289.19739\tMAPE: 14633498859659.1230468750\tMAE: 0.0210248887\n",
700 | "16:\tELBO: -1289.52069\tMAPE: 14092612942579.2890625000\tMAE: 0.0211173260\n",
701 | "17:\tELBO: -1290.03879\tMAPE: 13801933987812.0390625000\tMAE: 0.0206641792\n",
702 | "18:\tELBO: -1290.16760\tMAPE: 13290044942357.8632812500\tMAE: 0.0203791469\n",
703 | "19:\tELBO: -1290.23337\tMAPE: 12748134807936.6074218750\tMAE: 0.0202005950\n",
704 | "20:\tELBO: -1290.23775\tMAPE: 12247109000225.9257812500\tMAE: 0.0200879649\n",
705 | "21:\tELBO: -1290.28437\tMAPE: 12363286338668.9394531250\tMAE: 0.0201714460\n",
706 | "22:\tELBO: -1290.29538\tMAPE: 12351073774798.8535156250\tMAE: 0.0202223645\n",
707 | "23:\tELBO: -1290.30312\tMAPE: 12297003809169.9199218750\tMAE: 0.0202626254\n",
708 | "24:\tELBO: -1290.31552\tMAPE: 12319553555632.9140625000\tMAE: 0.0202842205\n",
709 | "25:\tELBO: -1290.35182\tMAPE: 12649195622623.8671875000\tMAE: 0.0204453022\n",
710 | "26:\tELBO: -1290.39222\tMAPE: 12908211477452.8984375000\tMAE: 0.0204750693\n",
711 | "27:\tELBO: -1290.42443\tMAPE: 13179198280345.8457031250\tMAE: 0.0204980342\n",
712 | "28:\tELBO: -1290.43947\tMAPE: 13332886561229.6386718750\tMAE: 0.0205425863\n",
713 | "29:\tELBO: -1290.44995\tMAPE: 13329311279487.5566406250\tMAE: 0.0205595732\n",
714 | "30:\tELBO: -1290.45561\tMAPE: 13179505627532.2324218750\tMAE: 0.0206054121\n",
715 | "31:\tELBO: -1290.46224\tMAPE: 13047510318376.0332031250\tMAE: 0.0205803832\n",
716 | "32:\tELBO: -1290.46537\tMAPE: 12947857177494.5625000000\tMAE: 0.0205791208\n",
717 | "33:\tELBO: -1290.46669\tMAPE: 12860256844218.4199218750\tMAE: 0.0205846158\n",
718 | "34:\tELBO: -1290.46707\tMAPE: 12825879176747.0566406250\tMAE: 0.0205932140\n",
719 | "35:\tELBO: -1290.46800\tMAPE: 12756483685447.3164062500\tMAE: 0.0206013891\n",
720 | "36:\tELBO: -1290.46977\tMAPE: 12624324553345.8730468750\tMAE: 0.0206052735\n",
721 | "37:\tELBO: -1290.47259\tMAPE: 12264485492440.3300781250\tMAE: 0.0205518417\n",
722 | "38:\tELBO: -1290.47297\tMAPE: 12207938794877.0761718750\tMAE: 0.0205460684\n",
723 | "39:\tELBO: -1290.47369\tMAPE: 12110260250434.9199218750\tMAE: 0.0205136369\n",
724 | "40:\tELBO: -1290.47382\tMAPE: 12114724619568.5546875000\tMAE: 0.0205012369\n",
725 | "41:\tELBO: -1290.47387\tMAPE: 12133582674304.8105468750\tMAE: 0.0205032890\n",
726 | "42:\tELBO: -1290.47393\tMAPE: 12154391564985.3164062500\tMAE: 0.0205068143\n",
727 | "43:\tELBO: -1290.47403\tMAPE: 12170550383214.6328125000\tMAE: 0.0205120134\n",
728 | "44:\tELBO: -1290.47405\tMAPE: 12163216194653.5332031250\tMAE: 0.0205112301\n",
729 | "45:\tELBO: -1290.47409\tMAPE: 12152212241121.8808593750\tMAE: 0.0205119486\n",
730 | "46:\tELBO: -1290.47410\tMAPE: 12138552001941.0253906250\tMAE: 0.0205098486\n",
731 | "47:\tELBO: -1290.47410\tMAPE: 12136147824151.6816406250\tMAE: 0.0205101848\n"
732 | ]
733 | }
734 | ],
735 | "source": [
736 | "#kernel_name = \"stoch_heat_vector_pseudo_diff_3\"\n",
737 | "kernel_name = \"stoch_wave_kernel_nu_52\"\n",
738 | "\n",
739 | "for i, rs in enumerate(RANDOM_SEEDS):\n",
740 | " kernel = copy.deepcopy(kernels[kernel_name])\n",
741 | " utils.set_all_random_seeds(rs)\n",
742 | "\n",
743 | " utils.set_all_random_seeds(rs)\n",
744 | " train_X, train_y, test_X, test_y, qt = data_utils.generate_dataset(\n",
745 | " X, y, NUM_TRAIN, NUM_TEST,\n",
746 | " start=START + 10 * len(graph.nodes()) + i * len(graph.nodes()), log_target=False, rs=rs,\n",
747 | " interpolation=False)\n",
748 | " start = time.time()\n",
749 | " result, gprocess = utils_opt.evaluate_kernel_mcmc(\n",
750 | " kernel, train_X, train_y, test_X, test_y, graph,\n",
751 | " n_iter=N_ITER, dump_everything=False, dump_directory=DUMP_DIRECTORY)\n",
752 | " results[kernel_name][rs] = result"
753 | ]
754 | },
755 | {
756 | "cell_type": "code",
757 | "execution_count": 8,
758 | "metadata": {},
759 | "outputs": [],
760 | "source": [
761 | "agg_kernel_results = utils_postproc.parse_results(results)"
762 | ]
763 | },
764 | {
765 | "cell_type": "code",
766 | "execution_count": 9,
767 | "metadata": {},
768 | "outputs": [
769 | {
770 | "name": "stdout",
771 | "output_type": "stream",
772 | "text": [
773 | "stoch_wave_kernel_nu_52\n",
774 | "Mean: 0.0189 $\\pm$ 0.0028\n",
775 | "Confidence Interval 95%%: 0.01610405905604751, 0.021731816631653467\n",
776 | "(0.01610405905604751, 0.021731816631653467)\n",
777 | "Data: [0.025276921027308018, 0.019689531058666873, 0.01286791425585474, 0.015034357473585053, 0.01436047199927747, 0.015805523963914877, 0.01709513207841518, 0.016351642022197983, 0.019562474216397095, 0.02513599117882034, 0.02532511004077812, 0.02051018481099016]\n",
778 | "====================================\n"
779 | ]
780 | }
781 | ],
782 | "source": [
783 | "for kernel_name in results.keys():\n",
784 | " print(kernel_name)\n",
785 | " utils_postproc.stats_array([r[\"MAE\"] for r in results[kernel_name].values()])\n",
786 | " print(\"====================================\")"
787 | ]
788 | },
789 | {
790 | "cell_type": "markdown",
791 | "metadata": {},
792 | "source": [
793 | "## Interpolation (2)\n",
794 | "stoch_wave_kernel_nu_52\n",
795 | "Mean: 0.0028 $\\pm$ 0.0006\n",
796 | "Confidence Interval 95%: 0.0022401338710728927, 0.003451229086341565\n",
797 | "(0.0022401338710728927, 0.003451229086341565)\n",
798 | "\n",
799 | "Data: [0.0017450406668729596, 0.004414391742788683, 0.004453214111611686, 0.0021427154657830167, 0.001821253925833599, 0.0024079721248159744, 0.0025783857412016963, 0.0022940820372391252, 0.0031607236224262952, 0.0022596849189193305, 0.003928416917940637, 0.0029422964690537417]\n",
800 | "\n",
801 | "====================================\n",
802 | "\n",
803 | "stoch_heat_vector_pseudo_diff_3\n",
804 | "Mean: 0.0047 $\\pm$ 0.0014\n",
805 | "Confidence Interval 95%: 0.003322288909083467, 0.0060250689146361955\n",
806 | "(0.0033222889090834666, 0.006025068914636196)\n",
807 | "\n",
808 | "Data: [0.0039073701859408975, 0.00502979652575883, 0.007811263104709524, 0.002084047933587993, 0.004071909169771604, 0.0025462248547979053, 0.004087074038491602, 0.003633725543103668, 0.006207459554040093, 0.0028253794905910993, 0.009191976295210859, 0.004687920246313]\n",
809 | "\n",
810 | "\n",
811 | "## Extrapolation \n",
812 | "stoch_heat_vector_pseudo_diff_3\n",
813 | "Mean: 0.0274 $\\pm$ 0.0067\n",
814 | "Confidence Interval 95%: 0.020609731103628355, 0.03410758679848596\n",
815 | "(0.020609731103628355, 0.03410758679848596)\n",
816 | "\n",
817 | "Data: [0.03236539513724462, 0.022677598991751556, 0.01165332385894794, 0.007914633900852069, 0.01911465579315055, 0.028778972004901007, 0.03587617517178525, 0.039150246636727705, 0.039725863943984634, 0.03717489520129039, 0.03130447329271192, 0.022567673479338222]\n",
818 | "\n",
819 | "====================================\n",
820 | "\n",
821 | "stoch_wave_kernel_nu_52\n",
822 | "Mean: 0.0189 $\\pm$ 0.0028\n",
823 | "Confidence Interval 95%: 0.01610405905604751, 0.021731816631653467\n",
824 | "(0.01610405905604751, 0.021731816631653467)\n",
825 | "\n",
826 | "Data: [0.025276921027308018, 0.019689531058666873, 0.01286791425585474, 0.015034357473585053, 0.01436047199927747, 0.015805523963914877, 0.01709513207841518, 0.016351642022197983, 0.019562474216397095, 0.02513599117882034, 0.02532511004077812, 0.02051018481099016]\n",
827 | "\n",
828 | "====================================\n",
829 | "\n"
830 | ]
831 | },
832 | {
833 | "cell_type": "code",
834 | "execution_count": 10,
835 | "metadata": {},
836 | "outputs": [
837 | {
838 | "name": "stdout",
839 | "output_type": "stream",
840 | "text": [
841 | "0:\tELBO: -474.96614\tMAPE: 24645810085517.2578125000\tMAE: 0.0296326765\n",
842 | "1:\tELBO: -576.02563\tMAPE: 19003437695932.1757812500\tMAE: 0.0259727013\n",
843 | "2:\tELBO: -713.54835\tMAPE: 12498801157852.0820312500\tMAE: 0.0201895766\n",
844 | "3:\tELBO: -822.26806\tMAPE: 13805302588245.2500000000\tMAE: 0.0130924225\n",
845 | "4:\tELBO: -1022.18304\tMAPE: 7525847832508.3515625000\tMAE: 0.0059436819\n",
846 | "5:\tELBO: -1048.62851\tMAPE: 6051768189694.6757812500\tMAE: 0.0051288084\n",
847 | "6:\tELBO: -1076.66693\tMAPE: 4967651560894.9482421875\tMAE: 0.0042977134\n",
848 | "7:\tELBO: -1087.22058\tMAPE: 3988782965217.1567382812\tMAE: 0.0040269194\n",
849 | "8:\tELBO: -1094.16457\tMAPE: 4297801887967.3291015625\tMAE: 0.0039813612\n",
850 | "9:\tELBO: -1098.30482\tMAPE: 4165111177225.1972656250\tMAE: 0.0038234018\n",
851 | "10:\tELBO: -1107.26751\tMAPE: 3840882090748.3291015625\tMAE: 0.0034909915\n",
852 | "11:\tELBO: -1113.32152\tMAPE: 3623277855807.8310546875\tMAE: 0.0033348642\n",
853 | "12:\tELBO: -1120.58866\tMAPE: 3591378790111.0405273438\tMAE: 0.0032380606\n",
854 | "13:\tELBO: -1124.32423\tMAPE: 3654870613543.1000976562\tMAE: 0.0032316609\n",
855 | "14:\tELBO: -1128.08389\tMAPE: 3345362973381.7099609375\tMAE: 0.0030733673\n",
856 | "15:\tELBO: -1130.62344\tMAPE: 3197316655935.5991210938\tMAE: 0.0029438184\n",
857 | "16:\tELBO: -1131.27964\tMAPE: 2958687740761.7348632812\tMAE: 0.0028542106\n",
858 | "17:\tELBO: -1131.36479\tMAPE: 2970904410209.6284179688\tMAE: 0.0028589393\n",
859 | "18:\tELBO: -1131.48152\tMAPE: 2979031013072.9394531250\tMAE: 0.0028658297\n",
860 | "19:\tELBO: -1131.64123\tMAPE: 2959802652426.6679687500\tMAE: 0.0028606358\n",
861 | "20:\tELBO: -1131.75535\tMAPE: 3073073396650.7685546875\tMAE: 0.0028898361\n",
862 | "21:\tELBO: -1132.08893\tMAPE: 3011894854709.0009765625\tMAE: 0.0028749672\n",
863 | "22:\tELBO: -1132.51635\tMAPE: 2959290561555.9931640625\tMAE: 0.0028619974\n",
864 | "23:\tELBO: -1132.75207\tMAPE: 2967361042712.6123046875\tMAE: 0.0028700259\n",
865 | "24:\tELBO: -1132.76807\tMAPE: 2971788132412.2812500000\tMAE: 0.0028652019\n",
866 | "25:\tELBO: -1132.89897\tMAPE: 2999267443088.4072265625\tMAE: 0.0028788519\n",
867 | "26:\tELBO: -1132.95688\tMAPE: 3015906129854.9882812500\tMAE: 0.0028875495\n",
868 | "27:\tELBO: -1132.97280\tMAPE: 3032049941358.5253906250\tMAE: 0.0028907912\n",
869 | "28:\tELBO: -1133.00444\tMAPE: 3017220500193.0249023438\tMAE: 0.0028850949\n",
870 | "29:\tELBO: -1133.02348\tMAPE: 3012677241533.1938476562\tMAE: 0.0028815979\n",
871 | "30:\tELBO: -1133.03020\tMAPE: 2997158793964.4775390625\tMAE: 0.0028736042\n",
872 | "31:\tELBO: -1133.03664\tMAPE: 3002187990900.9667968750\tMAE: 0.0028734929\n",
873 | "32:\tELBO: -1133.03756\tMAPE: 3003889222497.9687500000\tMAE: 0.0028734281\n",
874 | "33:\tELBO: -1133.04032\tMAPE: 3004968020183.5976562500\tMAE: 0.0028721141\n",
875 | "34:\tELBO: -1133.04306\tMAPE: 3004327370774.7822265625\tMAE: 0.0028695747\n",
876 | "35:\tELBO: -1133.04705\tMAPE: 2997328580704.2617187500\tMAE: 0.0028673854\n",
877 | "36:\tELBO: -1133.05083\tMAPE: 2996620623659.6010742188\tMAE: 0.0028663415\n",
878 | "37:\tELBO: -1133.05592\tMAPE: 2990269737402.7709960938\tMAE: 0.0028667763\n",
879 | "38:\tELBO: -1133.06188\tMAPE: 2990675070556.4233398438\tMAE: 0.0028690704\n",
880 | "39:\tELBO: -1133.06948\tMAPE: 2988114790390.9287109375\tMAE: 0.0028703060\n",
881 | "40:\tELBO: -1133.07193\tMAPE: 3007199425001.3310546875\tMAE: 0.0028773070\n",
882 | "41:\tELBO: -1133.07755\tMAPE: 2998259084262.6542968750\tMAE: 0.0028723819\n",
883 | "42:\tELBO: -1133.07892\tMAPE: 2995706175100.3784179688\tMAE: 0.0028700472\n",
884 | "43:\tELBO: -1133.07947\tMAPE: 2995147920854.4750976562\tMAE: 0.0028692143\n",
885 | "44:\tELBO: -1133.07960\tMAPE: 3001760103112.0590820312\tMAE: 0.0028715120\n",
886 | "45:\tELBO: -1133.07983\tMAPE: 2999767305004.9653320312\tMAE: 0.0028713374\n",
887 | "46:\tELBO: -1133.07985\tMAPE: 2999859591215.7187500000\tMAE: 0.0028713858\n",
888 | "47:\tELBO: -1133.07986\tMAPE: 3000046286209.7587890625\tMAE: 0.0028715826\n",
889 | "48:\tELBO: -1133.07986\tMAPE: 3000302612850.5224609375\tMAE: 0.0028717173\n",
890 | "49:\tELBO: -1133.07986\tMAPE: 3000170185181.0400390625\tMAE: 0.0028717180\n"
891 | ]
892 | }
893 | ],
894 | "source": [
895 | "kernel = copy.deepcopy(kernels[\"stoch_wave_kernel_nu_52\"])\n",
896 | "#kernel = copy.deepcopy(kernels[\"stoch_heat_vector_pseudo_diff_3\"])\n",
897 | "utils.set_all_random_seeds(rs)\n",
898 | "\n",
899 | "utils.set_all_random_seeds(rs)\n",
900 | "train_X, train_y, test_X, test_y, qt = data_utils.generate_dataset(\n",
901 | " X, y, NUM_TRAIN, NUM_TEST,\n",
902 | " start=START + 10 * len(graph.nodes()), log_target=False, rs=rs,\n",
903 | " interpolation=True)\n",
904 | "start = time.time()\n",
905 | "result, gprocess = utils_opt.evaluate_kernel_mcmc(\n",
906 | " kernel, train_X, train_y, test_X, test_y, graph,\n",
907 | " n_iter=N_ITER, dump_everything=False, dump_directory=DUMP_DIRECTORY)\n",
908 | "results[kernel_name][rs] = result"
909 | ]
910 | },
911 | {
912 | "cell_type": "code",
913 | "execution_count": 14,
914 | "metadata": {},
915 | "outputs": [],
916 | "source": [
917 | "def filter_ds(X, y, node_id):\n",
918 | " inds = [i for i in range(X.shape[0]) if X[i, 0] == node_id]\n",
919 | " return X[inds], np.array(y)[inds][:]\n",
920 | "\n",
921 | "\n",
922 | "def plot(m, X_train, signal, node_id):\n",
923 | " xmin, xmax = 0.5, 5\n",
924 | " xx = np.linspace(xmin, xmax, 100)\n",
925 | " xx = np.array([[node_id, x] for x in xx])\n",
926 | " \n",
927 | " mean, var = m.predict_y(xx)\n",
928 | " plt.figure(figsize=(12, 6))\n",
929 | " X_train, signal = filter_ds(X_train, signal, node_id)\n",
930 | " plt.plot(X_train[:, 1], signal, 'kx', mew=2)\n",
931 | " plt.plot(xx[:, 1], mean, 'b', lw=2)\n",
932 | " plt.fill_between(xx[:, 1], mean[:, 0] - 2 * np.sqrt(var[:, 0]), mean[:, 0] + 2 * np.sqrt(var[:, 0]), color='blue', alpha=0.2)\n",
933 | " plt.xlim(xmin, xmax)\n",
934 | " plt.title(\"Fit of GP (SWEK) model to synthetic wave dataset (node: {})\".format(node_id), fontsize=24)\n",
935 | " plt.savefig(\"images/swek_fit_wave_{}.pdf\".format(node_id))"
936 | ]
937 | },
938 | {
939 | "cell_type": "code",
940 | "execution_count": 15,
941 | "metadata": {},
942 | "outputs": [
943 | {
944 | "data": {
945 | "image/png": "\n",
946 | "text/plain": [
947 | ""
948 | ]
949 | },
950 | "metadata": {
951 | "needs_background": "light"
952 | },
953 | "output_type": "display_data"
954 | }
955 | ],
956 | "source": [
957 | "plot(gprocess, train_X, train_y, 2)"
958 | ]
959 | },
960 | {
961 | "cell_type": "code",
962 | "execution_count": null,
963 | "metadata": {},
964 | "outputs": [],
965 | "source": []
966 | }
967 | ],
968 | "metadata": {
969 | "kernelspec": {
970 | "display_name": "Python 3",
971 | "language": "python",
972 | "name": "python3"
973 | },
974 | "language_info": {
975 | "codemirror_mode": {
976 | "name": "ipython",
977 | "version": 3
978 | },
979 | "file_extension": ".py",
980 | "mimetype": "text/x-python",
981 | "name": "python",
982 | "nbconvert_exporter": "python",
983 | "pygments_lexer": "ipython3",
984 | "version": "3.6.5"
985 | }
986 | },
987 | "nbformat": 4,
988 | "nbformat_minor": 4
989 | }
990 |
--------------------------------------------------------------------------------
/experiments/run_chicken_pox.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import json
4 | import time
5 | import argparse
6 | import copy
7 | import sklearn
8 |
9 | import networkx as nx
10 | from tqdm import tqdm
11 |
12 | import gpflow
13 | import tensorflow as tf
14 |
15 | from graph_kernels import utils
16 | from graph_kernels import data_utils
17 | from graph_kernels import utils_opt
18 | from graph_kernels import time_kernels
19 |
20 | gpflow.config.set_default_jitter(1e-4)
21 | gpflow.config.set_default_float(tf.float64)
22 | f64 = gpflow.utilities.to_default_float
23 |
24 |
def parse_arguments(argv=None):
    """Parse the command-line options of the chicken-pox experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            ``argparse`` read ``sys.argv[1:]``; an explicit list makes the
            parser usable from tests and notebooks.

    Returns:
        argparse.Namespace with ``dump_directory``, ``num_test_weeks`` and
        ``interpolation`` attributes.
    """
    parser = argparse.ArgumentParser(description='Toy epidemiological dataset.')
    parser.add_argument('--dump_directory', type=str,
                        help='Path to directory with results.',
                        default="dump_directory")

    parser.add_argument('--num_test_weeks', type=int, help='Number of test weeks.', default=2)

    # Both flags write the same destination; with neither flag given the
    # experiment runs in extrapolation mode.
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--interpolation', action='store_true', default=False,
                       help='Evaluate the models on the interpolation task.')
    group.add_argument('--extrapolation', dest='interpolation', action='store_false',
                       help='Evaluate the models on the extrapolation task (default).')

    return parser.parse_args(argv)
39 |
40 |
args = parse_arguments()

INTERPOLATION = args.interpolation

# Data files are located relative to this script, independent of the working
# directory the experiment is launched from.
DATA_FOLDER = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "../data/hungary_chicken_pox/")
GRAPH_PATH = os.path.join(DATA_FOLDER, "hungary_county_edges.csv")

# County adjacency graph with its self-loops removed.
graph = data_utils.load_hungary_graph(GRAPH_PATH)
graph.remove_edges_from(list(nx.selfloop_edges(graph)))

# Sizes are counted in dataset rows; the dataset holds one row per
# (node, time step), so multiples of the node count are whole time steps.
NUM_TEST_WEEKS = args.num_test_weeks
NUM_TRAIN = 103 * len(graph.nodes())
NUM_TEST = NUM_TEST_WEEKS * len(graph.nodes())
N_ITER = 2_000

_BASE_SEEDS = [23, 42, 82, 100, 123, 223]
RANDOM_SEEDS = _BASE_SEEDS + [2 * seed for seed in _BASE_SEEDS]

DUMP_DIRECTORY = args.dump_directory
DUMP_EVERYTHING = False
os.makedirs(DUMP_DIRECTORY, exist_ok=True)

# `graph` is rebound here to the copy whose nodes are relabelled with
# integer ids; every kernel below is built on that relabelled graph.
X, y, graph, _ = data_utils.load_hungary_dataset(
    graph, path_to_csv=os.path.join(DATA_FOLDER, "hungary_chickenpox.csv"))

_n_nodes = len(graph.nodes())
exp_kernels = {
    # Disabled: "td_laplacian": time_kernels.TimeDistributedLaplacianKernel(graph),
    "stoch_heat_vector_pseudo_diff_1_scalar": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=1 / 2,
        kappa=1, variance=1.),
    "td_matern_nu_52_d_1": time_kernels.TimeDistributedMaternKernel(graph, nu=5 / 2, kappa=1),
    "td_matern_nu_32_d_1": time_kernels.TimeDistributedMaternKernel(graph, nu=3 / 2, kappa=1),
    "td_matern_nu_12_d_1": time_kernels.TimeDistributedMaternKernel(graph, nu=1 / 2, kappa=1),
    "stoch_heat_vector_pseudo_diff_1": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=5 / 2,
        kappa=1, variance=[1.] * _n_nodes),
    "stoch_heat_vector_pseudo_diff_2": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=3 / 2,
        kappa=1, variance=[1.] * _n_nodes),
    "stoch_heat_vector_pseudo_diff_3": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=1 / 2,
        kappa=1, variance=[1.] * _n_nodes),
}
86 |
87 |
if __name__ == "__main__":
    # Import the submodule explicitly: a bare ``import sklearn`` does not
    # guarantee that ``sklearn.model_selection`` has been loaded.
    from sklearn.model_selection import train_test_split

    results = {}
    for kernel_name, kernel in exp_kernels.items():
        print("Evaluating kernel ", kernel_name)
        results[kernel_name] = {}
        for i, rs in tqdm(enumerate(RANDOM_SEEDS), total=len(RANDOM_SEEDS)):
            utils.set_all_random_seeds(rs)
            # Seed i uses a window shifted by i * n_nodes rows.
            train_X, train_y, test_X, test_y, qt = data_utils.generate_dataset(
                X, y.ravel(),
                num_training_data=NUM_TRAIN, num_testing_data=NUM_TEST,
                start=i * len(graph.nodes()),
                log_target=True, rs=rs)
            if INTERPOLATION:
                # Hold out 10% of the training window itself; the
                # extrapolation test rows are not used in this mode.
                train_X, test_X, train_y, test_y = train_test_split(
                    train_X, train_y, test_size=0.1, random_state=rs)
            train_y = tf.cast(train_y, tf.float64)
            test_y = tf.cast(test_y, tf.float64)

            start = time.time()
            result, gprocess = utils_opt.evaluate_kernel_mcmc(
                copy.deepcopy(kernel), train_X, train_y, test_X, test_y, graph,
                transformer=qt,
                n_iter=N_ITER, optimizer_name="LBFGS")
            results[kernel_name][rs] = result
            results[kernel_name][rs]["time"] = time.time() - start
        # Checkpoint after every kernel; the context manager closes (and
        # flushes) the file instead of leaking the handle.
        with open(os.path.join(DUMP_DIRECTORY, "results.json"), "w") as results_file:
            json.dump(results, results_file)
115 |
--------------------------------------------------------------------------------
/experiments/run_covid_experiments.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import json
4 | import time
5 | import argparse
6 | import copy
7 |
8 | from tqdm import tqdm
9 |
10 | import gpflow
11 |
12 | from graph_kernels import utils
13 | from graph_kernels import time_kernels
14 | from graph_kernels import data_utils
15 | from graph_kernels import utils_opt
16 |
17 |
def parse_arguments(argv=None):
    """Parse the command-line options of the COVID-19 experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            ``argparse`` read ``sys.argv[1:]``; an explicit list makes the
            parser usable from tests and notebooks.

    Returns:
        argparse.Namespace with ``log_target``, ``use_flight_graph``,
        ``use_normalized_target``, ``interpolation``, ``dump_directory`` and
        ``num_test_weeks`` attributes.
    """
    parser = argparse.ArgumentParser(description='COVID-19 across the US')
    parser.add_argument('--log_target', action='store_true', default=False,
                        help='Apply log transform to the target.')
    parser.add_argument('--no-log_target', dest='log_target', action='store_false')

    parser.add_argument('--use_flight_graph', action='store_true', default=False,
                        help='Use graph that contains information about the flights.')
    parser.add_argument('--no-use_flight_graph', dest='use_flight_graph', action='store_false')

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--use_normalized_target', action='store_true', default=False,
                       help='Normalize the target by population in a state.')
    group.add_argument('--no-use_normalized_target', dest='use_normalized_target', action='store_false')

    # Both flags write the same destination; extrapolation is the default.
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--interpolation', action='store_true', default=False,
                       help='Evaluate the models on the interpolation task.')
    group.add_argument('--extrapolation', dest='interpolation', action='store_false')

    parser.add_argument('--dump_directory', type=str, help='Path to directory with results.',
                        default="dump_directory")

    parser.add_argument('--num_test_weeks', type=int, help='Number of test weeks.', default=2)

    return parser.parse_args(argv)
44 |
45 |
args = parse_arguments()

gpflow.config.set_default_jitter(1e-4)


def _load_pickle(path):
    """Return the object unpickled from ``path``, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code; only load trusted files such
    as the ones shipped with this repository.
    """
    with open(path, "rb") as f:
        return pickle.load(f)


# Data files are resolved relative to this script.
DATA_FOLDER = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "../data/covid_data/")

INTERPOLATION = args.interpolation
USE_FLIGHT_GRAPH = args.use_flight_graph
if USE_FLIGHT_GRAPH:
    # NOTE(review): state_graph.pkl is not among the files in data/covid_data
    # in this repository; it has to be supplied separately.
    GRAPH_PATH = os.path.join(DATA_FOLDER, "state_graph.pkl")
else:
    GRAPH_PATH = os.path.join(DATA_FOLDER, "g.pkl")
graph = _load_pickle(GRAPH_PATH)
N_NODES = len(graph.nodes())

NUM_TEST_WEEKS = args.num_test_weeks
NUM_TRAIN = 33 * N_NODES
NUM_TEST = NUM_TEST_WEEKS * N_NODES
START = 4 * N_NODES * 2
N_ITER = 5_000

IS_PREDICT_CASES = True
LOG_TARGET = args.log_target
USE_NORMALIZED_TARGET = args.use_normalized_target


RANDOM_SEEDS = [23, 42, 82, 100, 123,
                2 * 23, 2 * 42, 2 * 82, 2 * 100, 2 * 123]

X_PATH = os.path.join(DATA_FOLDER, "X.pkl")
if USE_NORMALIZED_TARGET:
    Y_CASES_PATH = os.path.join(DATA_FOLDER, "y_cases_normalized.pkl")
    Y_DEATHS_PATH = os.path.join(DATA_FOLDER, "y_deaths_normalized.pkl")
else:
    Y_CASES_PATH = os.path.join(DATA_FOLDER, "y_cases.pkl")
    Y_DEATHS_PATH = os.path.join(DATA_FOLDER, "y_deaths.pkl")


FROM_STATE_TO_ID_PATH = os.path.join(DATA_FOLDER, "from_state_to_id.pkl")
DUMP_DIRECTORY = args.dump_directory
DUMP_EVERYTHING = False
os.makedirs(DUMP_DIRECTORY, exist_ok=True)


X = _load_pickle(X_PATH)
y_cases = _load_pickle(Y_CASES_PATH)
y_deaths = _load_pickle(Y_DEATHS_PATH)
from_state_to_id = _load_pickle(FROM_STATE_TO_ID_PATH)


y = y_cases if IS_PREDICT_CASES else y_deaths

# Clip negative targets to zero. This happens in place, so the selected
# source array (y_cases or y_deaths) is modified as well.
y[y < 0] = 0

exp_kernels = {
    "td_matern_nu_52_k_1": time_kernels.TimeDistributedMaternKernel(graph, nu=5 / 2, kappa=1),
    "td_matern_nu_32_k_1": time_kernels.TimeDistributedMaternKernel(graph, nu=3 / 2, kappa=1),
    "td_matern_nu_12_k_1": time_kernels.TimeDistributedMaternKernel(graph, nu=1 / 2, kappa=1),
    "stoch_heat_vector_pseudo_diff_1": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=5 / 2, kappa=1, variance=[1.] * len(graph.nodes())),
    "stoch_heat_vector_pseudo_diff_2": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=3 / 2, kappa=1, variance=[1.] * len(graph.nodes())),
    "stoch_heat_vector_pseudo_diff_3": time_kernels.StochasticHeatEquation(
        graph, c=0.1, use_pseudodifferential=True, nu=1 / 2, kappa=1, variance=[1.] * len(graph.nodes())),
}
116 |
117 |
if __name__ == "__main__":
    results = {}
    for kernel_name, kernel in exp_kernels.items():
        results[kernel_name] = {}
        for i, rs in tqdm(enumerate(RANDOM_SEEDS), total=len(RANDOM_SEEDS)):
            utils.set_all_random_seeds(rs)
            # Seed i uses a window shifted by i * N_NODES rows past START.
            train_X, train_y, test_X, test_y, qt = data_utils.generate_dataset(
                X, y, NUM_TRAIN, NUM_TEST,
                start=START + i * len(graph.nodes()), log_target=LOG_TARGET, rs=rs,
                interpolation=INTERPOLATION)
            print("Evaluating kernel ", kernel_name)
            start = time.time()
            result, gprocess = utils_opt.evaluate_kernel_mcmc(
                copy.deepcopy(kernel), train_X, train_y, test_X, test_y, graph,
                transformer=qt,
                n_iter=N_ITER, dump_directory=DUMP_DIRECTORY,
                dump_everything=DUMP_EVERYTHING, optimizer_name="LBFGS")
            results[kernel_name][rs] = result
            results[kernel_name][rs]["time"] = time.time() - start
        # Checkpoint after every kernel; the context manager closes (and
        # flushes) the file instead of leaking the handle.
        with open(os.path.join(DUMP_DIRECTORY, "results.json"), "w") as results_file:
            json.dump(results, results_file)
138 |
--------------------------------------------------------------------------------
/graph_kernels/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AaltoPML/spatiotemporal-graph-kernels/b82529b6a1dfeb4ff06fb9ea0952ec926f0d7ae0/graph_kernels/__init__.py
--------------------------------------------------------------------------------
/graph_kernels/data_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import networkx as nx
6 | import pickle
7 |
8 | import sklearn
9 | from sklearn.preprocessing import FunctionTransformer
10 |
11 | import matplotlib.pyplot as plt
12 |
13 |
# Default locations of the pickled heat-distribution datasets. They are
# resolved relative to this module (graph_kernels/ -> ../data) rather than
# the current working directory, so the readers work from any launch dir.
_DATA_ROOT = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "data"))
HEAT_DATASET_1d = os.path.join(_DATA_ROOT, "heat_distribution", "1d.pkl")
HEAT_DATASET_2d = os.path.join(_DATA_ROOT, "heat_distribution", "2d.pkl")
16 |
17 |
def same_component(v1, v2, N):
    """Return True when both vertices lie strictly on the same side of N / 2.

    A vertex equal to N / 2 (only possible for even N) belongs to neither side.
    """
    midpoint = N / 2
    both_left = v1 < midpoint and v2 < midpoint
    both_right = v1 > midpoint and v2 > midpoint
    return both_left or both_right
20 |
21 |
def get_noisy_signal(N, variance=0.1):
    """Return a noisy two-level signal: N // 2 values near -1, a 0, N // 2 near 1.

    The result has ``2 * (N // 2) + 1`` entries (N for odd N, N + 1 for even N).
    Note that ``variance`` is passed as the ``scale`` argument of
    ``np.random.normal``, i.e. it acts as the standard deviation.
    """
    # Draw the left half first, then the right half (same RNG order as a loop).
    left = [np.random.normal(-1, variance) for _ in range(N // 2)]
    right = [np.random.normal(1, variance) for _ in range(N // 2)]
    return np.array(left + [0] + right)
32 |
33 |
def generate_graph_n_comp(N, n_comp=3):
    """Placeholder for an ``n_comp``-community variant of ``generate_graph``.

    Raises:
        NotImplementedError: always; the generator has not been written yet.
    """
    raise NotImplementedError
36 |
37 |
def generate_graph(N, p):
    """Generate a random two-community graph whose vertices all touch the hub N // 2.

    Vertices are 0 .. N - 1. For every ordered pair (v1, v2) that
    ``same_component`` places on the same side of N / 2, an edge is added
    with probability p. Every unordered pair is visited twice, so a
    within-community edge appears with probability 1 - (1 - p) ** 2.
    All pairs involving the hub N // 2 are always connected.

    Returns:
        (G, signal, covariances):
            G: the networkx graph.
            signal: ``get_noisy_signal(N)`` (N + 1 entries when N is even).
            covariances: symmetric N x N array with ones on the diagonal and p
                for every edge of G (in both directions), zeros elsewhere.
    """
    covariances = np.zeros((N, N))
    G = nx.Graph()
    hub = N // 2
    for v1 in range(N):
        G.add_node(v1)
        covariances[v1, v1] = 1
        for v2 in range(N):
            if v1 == v2:
                continue
            # Random number is drawn only for same-community pairs (short
            # circuit), which keeps the RNG stream identical to before.
            if (same_component(v1, v2, N) and np.random.rand() < p) or v1 == hub or v2 == hub:
                G.add_edge(v1, v2)
                # Fill both triangles: G is undirected, so the matrix must be
                # symmetric even if only one ordering of the pair succeeded.
                covariances[v1, v2] = p
                covariances[v2, v1] = p
    return G, get_noisy_signal(N), covariances
52 |
53 |
def generate_ring_graph(N):
    """Build an N-node cycle together with a linear ramp signal.

    Node i is joined to (i + 1) % N and carries the value -1 + i * 2 / (N - 1),
    so the signal rises from -1 at node 0 to 1 at node N - 1.

    Returns:
        (graph, signal) with ``signal`` a NumPy array of length N.
    """
    step = 2 / (N - 1)
    ring = nx.Graph()
    ring.add_nodes_from(range(N))
    ring.add_edges_from((node, (node + 1) % N) for node in range(N))
    ramp = np.array([-1 + node * step for node in range(N)])
    return ring, ramp
60 |
61 |
def generate_lattice(n):
    """Return a 1-D lattice (path graph) on the nodes 0 .. n - 1.

    Consecutive integers are joined by an edge. ``n == 1`` yields a single
    isolated node; ``n <= 0`` yields an empty graph.
    """
    G = nx.Graph()
    # Add the nodes explicitly so that n == 1 is not silently dropped.
    G.add_nodes_from(range(n))
    G.add_edges_from(zip(range(n - 1), range(1, n)))
    return G
69 |
70 |
def generate_2d_lattice(n, m=None):
    """Return an ``m`` x ``n`` grid graph; ``m`` defaults to ``n`` (square grid).

    Nodes are ``(i, j)`` tuples with ``0 <= i < m`` and ``0 <= j < n``, as
    produced by networkx's ``grid_2d_graph(m, n)``.
    """
    # The optional argument is ``m``; previously only ``n`` was checked,
    # so generate_2d_lattice(k) passed m=None to networkx.
    if m is None:
        m = n
    if n is None:
        n = m
    return nx.generators.lattice.grid_2d_graph(m, n)
75 |
76 |
def draw_2d_lattice(G, signal=None):
    """Draw a 2-D lattice with every node positioned at its own (i, j) tuple.

    Node (i, j) is labelled ``i * 10 + j``; the labels are unique only while
    j < 10. ``signal`` (optional) colours the nodes.
    """
    positions = {node: node for node in G.nodes()}
    labels = {(i, j): i * 10 + j for i, j in G.nodes()}
    nx.draw_networkx(G, pos=positions, labels=labels, node_color=signal)
81 |
82 |
def plot_nodes_with_colors(g, signal, title="2 component graph", layout=nx.spring_layout, ax=None):
    """Draw ``g`` with nodes coloured by ``signal`` and attach a colorbar.

    Args:
        g: networkx graph to draw.
        signal: one value per node, in ``g.nodes()`` order.
        title: axes title.
        layout: callable mapping the graph to node positions; ``None`` falls
            back to ``nx.spring_layout``.
        ax: matplotlib axes to draw into; defaults to the current axes.

    Raises:
        ValueError: if ``signal`` does not have one entry per node.
    """
    if layout is None:
        layout = nx.spring_layout
    # Resolve the target axes once so edges, nodes, title, colorbar and the
    # axis switch all land on the same axes (previously only the nodes
    # honoured ``ax``).
    if ax is None:
        ax = plt.gca()

    nodes = g.nodes()
    if len(nodes) != len(signal):
        raise ValueError(
            "signal has {} entries but the graph has {} nodes".format(len(signal), len(nodes)))

    # Nodes and edges are drawn separately so the node collection can be
    # handed to the colorbar.
    pos = layout(g)
    nx.draw_networkx_edges(g, pos, alpha=0.2, ax=ax)
    nc = nx.draw_networkx_nodes(g, pos, nodelist=nodes, node_color=signal,
                                node_size=100, cmap=plt.cm.jet, ax=ax)
    ax.set_title(title)
    ax.figure.colorbar(nc, ax=ax)
    ax.set_axis_off()
98 |
99 |
def visualize_kernel_for_graph(gprocess, G, node=0, title=None):
    """Colour the nodes of ``G`` by their kernel covariance with ``node``.

    The kernel of ``gprocess`` is evaluated on ``sorted(G.nodes())`` and row
    ``node`` of that matrix is plotted, so for graphs labelled 0 .. N - 1
    ``node`` is the node id itself. The default (0) reproduces the previous
    behaviour, which ignored the argument and always used row 0.

    NOTE(review): the row is in sorted-node order while
    ``plot_nodes_with_colors`` colours in ``G.nodes()`` order — this assumes
    ``G.nodes()`` is already sorted; confirm for relabelled graphs.
    """
    all_nodes = sorted(G.nodes())
    covariance_train = gprocess.kernel.K(all_nodes)
    plot_nodes_with_colors(
        G,
        covariance_train[node],
        layout=nx.layout.circular_layout,
        title=title)
108 |
109 |
def read_heat_1d(path=HEAT_DATASET_1d):
    """Load and return the pickled 1-D heat-distribution dataset at ``path``.

    The file is closed deterministically via a context manager.
    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(path, "rb") as f:
        return pickle.load(f)
112 |
113 |
def read_heat_2d(path=HEAT_DATASET_2d):
    """Load and return the pickled 2-D heat-distribution dataset at ``path``.

    The file is closed deterministically via a context manager.
    NOTE: unpickling can execute arbitrary code; only load trusted files.
    """
    with open(path, "rb") as f:
        return pickle.load(f)
116 |
117 |
def build_graph_from_1d_points(x_lin):
    """Chain 1-D sample points into a path graph.

    Node i carries the attribute ``point=x_lin[i]``; nodes i - 1 and i are
    connected for every i >= 1.
    """
    n_points = x_lin.shape[0]
    G = nx.Graph()
    for idx in range(n_points):
        G.add_node(idx, point=x_lin[idx])
    for idx in range(1, n_points):
        G.add_edge(idx - 1, idx)
    return G
130 |
131 |
def build_graph_from_2d_points(X, Y):
    """Return the 4-neighbour grid graph over the index grid of ``X``.

    Nodes are ``(i, j)`` index pairs covering ``X.shape[0] x X.shape[1]``;
    each node is linked to its upper and left neighbours. ``Y`` is unused.
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    G = nx.Graph()
    for i in range(n_rows):
        for j in range(n_cols):
            node = (i, j)
            G.add_node(node)
            if i > 0:
                G.add_edge((i - 1, j), node)
            if j > 0:
                G.add_edge((i, j - 1), node)
    return G
142 |
143 |
def generate_dataset(X, y, num_training_data, num_testing_data, start=0, log_target=False, rs=42,
                     interpolation=False):
    """Cut a train/test split out of the rows of ``X`` / ``y`` starting at ``start``.

    Extrapolation (default): rows ``[start, start + num_training_data)`` are
    the training set and the following ``num_testing_data`` rows the test set.
    Interpolation: the same window of ``num_training_data + num_testing_data``
    rows is split 90/10 by ``train_test_split`` with ``random_state=rs``.

    Args:
        X: feature rows.
        y: target values; expected to be one-dimensional (it is indexed with
            ``np.newaxis`` to produce column vectors).
        num_training_data: number of rows in the training window.
        num_testing_data: number of rows following the training window.
        start: first row of the window.
        log_target: if True, apply ``log1p`` to both targets and return the
            transformer (whose inverse is ``expm1``).
        rs: random state for the interpolation split.
        interpolation: choose the interpolation split described above.

    Returns:
        (train_X, train_y, test_X, test_y, qt) with targets shaped (n, 1);
        ``qt`` is None when ``log_target`` is False.
    """
    start_test = start + num_training_data
    end_test = start_test + num_testing_data

    if interpolation:
        # Import the submodule explicitly: a bare ``import sklearn`` does not
        # guarantee that ``sklearn.model_selection`` has been loaded.
        from sklearn.model_selection import train_test_split
        train_X, test_X, train_y, test_y = train_test_split(
            X[start:end_test], y[start:end_test],
            test_size=0.1, random_state=rs)
        train_y, test_y = train_y[:, np.newaxis], test_y[:, np.newaxis]
    else:
        train_X, train_y = X[start:start_test], y[start:start_test, np.newaxis]
        test_X, test_y = X[start_test:end_test], y[start_test:end_test, np.newaxis]

    if log_target:
        qt = FunctionTransformer(func=np.log1p, inverse_func=np.expm1)
        train_y = qt.fit_transform(train_y)
        test_y = qt.transform(test_y)
    else:
        qt = None

    return train_X, train_y, test_X, test_y, qt
166 |
167 |
def load_hungary_graph(path_to_csv="../data/hungary_chicken_pox/hungary_county_edges.csv"):
    """Read the edge list at ``path_to_csv`` (columns ``name_1``, ``name_2``) into a graph.

    The default path is relative to the current working directory.
    """
    edge_table = pd.read_csv(path_to_csv)
    return nx.from_pandas_edgelist(edge_table, source="name_1", target="name_2")
172 |
173 |
def load_hungary_dataset(g, path_to_csv="../data/hungary_chicken_pox/hungary_chickenpox.csv"):
    """Flatten the per-node time series in ``path_to_csv`` into (X, y) rows.

    Each CSV row is a time step t and has one column per node of ``g``. For
    every t and every node v (in ``g.nodes()`` order) one sample is produced
    with features ``[id(v), t]`` and target ``row[v]``.

    Returns:
        (X, y, g_relabelled, from_node_to_id): X and y as NumPy arrays, the
        graph with nodes renamed to their integer ids, and the name -> id map.
    """
    frame = pd.read_csv(path_to_csv)
    node_list = list(g.nodes())
    from_node_to_id = {node: idx for idx, node in enumerate(node_list)}
    records = frame.to_dict(orient="records")

    X = [[from_node_to_id[node], t] for t, _ in enumerate(records) for node in node_list]
    y = [record[node] for record in records for node in node_list]

    relabelled = nx.relabel_nodes(g, from_node_to_id)
    return np.array(X), np.array(y), relabelled, from_node_to_id
185 |
186 |
def generate_new_chickenpox_dataset(start=0):
    """Export the chicken-pox data as ``{"edges": [...], "FX": [[...], ...]}``.

    ``FX`` has one row per time step from ``start`` onward and one column per
    node (integer ids from ``load_hungary_dataset``).
    """
    _, targets, graph, _ = load_hungary_dataset(load_hungary_graph())
    n_nodes = len(graph.nodes())
    return {
        "edges": list(graph.edges()),
        "FX": targets[start * n_nodes:].reshape((-1, n_nodes)).tolist(),
    }
195 |
196 |
def generate_new_covid19_dataset(start=0):
    """Export the COVID-19 case counts as ``{"edges": [...], "FX": [[...], ...]}``.

    Cases are read in blocks of ``n_nodes`` rows starting at row
    ``start * n_nodes``. Within a block, the value of row j is stored at
    column ``int(X[row][0])``; missing entries (e.g. a trailing partial
    block) stay 0. Negative counts are clipped to 0.

    The data folder is relative to the current working directory.
    """
    DATA_FOLDER = "../data/covid_data/"
    GRAPH_PATH = os.path.join(DATA_FOLDER, "g.pkl")
    X_PATH = os.path.join(DATA_FOLDER, "X.pkl")
    Y_CASES_PATH = os.path.join(DATA_FOLDER, "y_cases.pkl")

    def _load(path):
        # Context manager closes the file (the previous code leaked handles).
        with open(path, "rb") as f:
            return pickle.load(f)

    g = nx.relabel.convert_node_labels_to_integers(_load(GRAPH_PATH))
    X = _load(X_PATH)
    y_cases = _load(Y_CASES_PATH)
    y_cases[y_cases < 0] = 0

    n_nodes = len(g.nodes())
    y_cases_reordered = []
    for i in range(start * n_nodes, y_cases.shape[0], n_nodes):
        cur = y_cases[i:i + n_nodes]
        new_row = np.zeros(n_nodes)
        for j, x in enumerate(X[i:i + n_nodes]):
            new_row[int(x[0])] = cur[j]
        y_cases_reordered.append(new_row)

    return {
        "edges": list(g.edges()),
        "FX": np.array(y_cases_reordered).tolist(),
    }
221 |
--------------------------------------------------------------------------------
/graph_kernels/kernels.py:
--------------------------------------------------------------------------------
1 | import gpflow
2 | from gpflow import Parameter
3 | import tensorflow as tf
4 |
5 | import scipy.linalg
6 | import numpy as np
7 |
8 | from . import utils
9 |
10 |
def get_matern_kernel(L, nu, kappa):
    """Return the graph Matérn covariance as a float64 tensor.

    With ``M = (2 * nu / kappa**2) * I + L`` and ``A = M ** (nu / 2)``
    (SciPy fractional matrix power), the result is ``pinv(A^H A)``; for a
    symmetric ``L`` this equals ``pinv(M ** nu)``.
    """
    shift = 2 * nu / kappa**2
    operator = shift * np.eye(L.shape[0]) + L
    half_power = tf.cast(
        scipy.linalg.fractional_matrix_power(operator, nu / 2), dtype=tf.float64)
    gram = tf.matmul(half_power, half_power, adjoint_a=True)
    return tf.linalg.pinv(gram)
21 |
22 |
class LaplacianKernel(gpflow.kernels.base.Kernel):
    """Graph kernel ``variance * pinv(L^H L)`` with L the graph Laplacian.

    Inputs are node indices; ``K`` gathers the rows/columns of the
    precomputed covariance that correspond to them.
    """

    def __init__(self, sparse_adj_mat, variance=1.0, normalized_laplacian=True):
        super().__init__()
        self.variance = Parameter(variance, transform=gpflow.utilities.positive(), name="variance")
        self.sparse_adj_mat = sparse_adj_mat
        self.laplacian = utils.get_laplacian(sparse_adj_mat, normalized_laplacian)
        # Pseudo-inverse because L^H L is presumably singular (a Laplacian has
        # a zero eigenvalue), so an ordinary inverse need not exist.
        self.cov = tf.matmul(self.laplacian, self.laplacian, adjoint_a=True)
        self.cov = tf.linalg.pinv(self.cov)

    def K(self, X, Y=None, presliced=False):
        X = tf.reshape(tf.cast(X, tf.int32), [-1])
        X2 = tf.reshape(tf.cast(Y, tf.int32), [-1]) if Y is not None else X

        cov = self.variance * self.cov
        cov = tf.gather(tf.gather(cov, X, axis=0), X2, axis=1)
        return cov

    def K_diag(self, X, presliced=False):
        # diag(K(X, X))[i] == variance * cov[x_i, x_i]; gathering from the
        # diagonal is O(N) instead of materialising the N x N block.
        X = tf.reshape(tf.cast(X, tf.int32), [-1])
        return self.variance * tf.gather(tf.linalg.diag_part(self.cov), X)
42 |
43 |
class DiffusionKernel(gpflow.kernels.base.Kernel):
    """Graph diffusion kernel ``variance * expm(-beta * L)``.

    Inputs are node indices; ``K`` gathers the rows/columns of the
    precomputed covariance that correspond to them.
    """

    def __init__(self, sparse_adj_mat, variance=1.0, beta=0.1, normalized_laplacian=True):
        super().__init__()
        self.variance = Parameter(variance, transform=gpflow.utilities.positive(), name="variance")
        self.beta = beta

        self.sparse_adj_mat = sparse_adj_mat
        self.laplacian = utils.get_laplacian(sparse_adj_mat, normalized_laplacian)
        self.cov = tf.linalg.expm(-self.beta * self.laplacian)

    def K(self, X, Y=None, presliced=False):
        X = tf.reshape(tf.cast(X, tf.int32), [-1])
        X2 = tf.reshape(tf.cast(Y, tf.int32), [-1]) if Y is not None else X

        cov = self.variance * self.cov
        cov = tf.gather(tf.gather(cov, X, axis=0), X2, axis=1)

        return cov

    def K_diag(self, X, presliced=False):
        # diag(K(X, X))[i] == variance * cov[x_i, x_i]; gathering from the
        # diagonal is O(N) instead of materialising the N x N block.
        X = tf.reshape(tf.cast(X, tf.int32), [-1])
        return self.variance * tf.gather(tf.linalg.diag_part(self.cov), X)
65 |
66 |
class RandomWalkKernel(gpflow.kernels.base.Kernel):
    """Random-walk graph kernel ``variance * P P^T``.

    ``P`` is the adjacency matrix with self-loops added and each row divided
    by its degree (row sum). Inputs are node indices.
    """

    def __init__(self, sparse_adj_mat, variance=1.0):
        super().__init__()
        self.variance = Parameter(variance, transform=gpflow.utilities.positive(), name="variance")
        # Work on a copy so the caller's adjacency matrix does not silently
        # gain self-loops.
        adj = sparse_adj_mat.copy()
        adj[np.diag_indices(adj.shape[0])] = 1.0
        self.sparse_P = utils.sparse_mat_to_sparse_tensor(adj)
        # Degrees as an (N, 1) column so row i is divided by deg(i). scipy
        # sparse *matrices* return an (N, 1) np.matrix from sum(axis=1), but
        # sparse *arrays* return a flat (N,) array, which would broadcast
        # across columns instead.
        degrees = np.asarray(adj.sum(axis=1)).reshape(-1, 1)
        self.sparse_P = self.sparse_P / degrees
        self.cov = tf.sparse.sparse_dense_matmul(self.sparse_P, tf.sparse.to_dense(self.sparse_P), adjoint_b=True)

    def K(self, X, Y=None, presliced=False):
        X = tf.reshape(tf.cast(X, tf.int32), [-1])
        X2 = tf.reshape(tf.cast(Y, tf.int32), [-1]) if Y is not None else X
        cov = self.variance * self.cov
        cov = tf.gather(tf.gather(cov, X, axis=0), X2, axis=1)
        return cov

    def K_diag(self, X, presliced=False):
        # diag(K(X, X))[i] == variance * cov[x_i, x_i]; gathering from the
        # diagonal is O(N) instead of materialising the N x N block.
        X = tf.reshape(tf.cast(X, tf.int32), [-1])
        return self.variance * tf.gather(tf.linalg.diag_part(self.cov), X)
85 |
86 |
class MaternKernel(gpflow.kernels.base.Kernel):
    """Graph Matérn kernel ``variance * get_matern_kernel(L, nu, kappa)``.

    Inputs are node indices; ``K`` gathers the rows/columns of the
    precomputed covariance that correspond to them.
    """

    def __init__(self, sparse_adj_mat, nu, kappa, variance=1.0, normalized_laplacian=True):
        super().__init__()
        self.variance = Parameter(variance, transform=gpflow.utilities.positive(), name="variance")
        self.nu = nu
        self.kappa = kappa

        self.normalized_laplacian = normalized_laplacian
        self.laplacian = utils.get_laplacian(sparse_adj_mat=sparse_adj_mat, normalized_laplacian=normalized_laplacian)

        self.matern_kernel = get_matern_kernel(self.laplacian, self.nu, self.kappa)

    def K(self, X, Y=None, presliced=False):
        X = tf.reshape(tf.cast(X, tf.int32), [-1])
        X2 = tf.reshape(tf.cast(Y, tf.int32), [-1]) if Y is not None else X

        cov = self.variance * self.matern_kernel
        cov = tf.gather(tf.gather(cov, X, axis=0), X2, axis=1)
        return cov

    def K_diag(self, X, presliced=False):
        # diag(K(X, X))[i] == variance * M[x_i, x_i]; gathering from the
        # diagonal is O(N) instead of materialising the N x N block.
        X = tf.reshape(tf.cast(X, tf.int32), [-1])
        return self.variance * tf.gather(tf.linalg.diag_part(self.matern_kernel), X)
109 |
110 |
class WaveKernel(gpflow.kernels.base.Kernel):
    """Graph kernel ``variance * pinv(L^{1/2}) @ sin(c1 * L^{1/2})``.

    L is the normalized Laplacian. Inputs are node indices.

    NOTE(review): ``K`` recomputes the matrix sine with SciPy on every call
    and stores the results on ``self.sin`` / ``self.cov``. Because SciPy runs
    outside TensorFlow, gradients presumably do not reach ``c1``. ``beta``
    only influences the initial ``self.sin`` and ``c2`` is not used by the
    computation — confirm before relying on either.
    """

    def __init__(self, sparse_adj_mat, variance=1.0, beta=0.1, c1=1, c2=1):
        super().__init__()
        positive = gpflow.utilities.positive
        self.variance = Parameter(variance, transform=positive(), name="variance")
        self.c1 = Parameter(c1, transform=positive(), name="c1")
        self.c2 = Parameter(c2, transform=positive(), name="c2")

        self.beta = beta

        self.laplacian = utils.get_normalized_laplacian(sparse_adj_mat)
        root = scipy.linalg.sqrtm(self.laplacian.numpy())
        self.sqrt_lapl = tf.constant(root, dtype=tf.float64)
        self.sqrt_inv_lapl = tf.constant(np.linalg.pinv(self.sqrt_lapl), dtype=tf.float64)
        # Initial value only; K() overwrites it using c1 instead of beta.
        self.sin = self.sqrt_inv_lapl @ scipy.linalg.sinm(self.sqrt_lapl * self.beta)
        self.cov = None

    def K(self, X, Y=None, presliced=False):
        rows = tf.reshape(tf.cast(X, tf.int32), [-1])
        cols = rows if Y is None else tf.reshape(tf.cast(Y, tf.int32), [-1])
        self.sin = self.sqrt_inv_lapl @ scipy.linalg.sinm(self.sqrt_lapl * self.c1)
        self.cov = self.variance * self.sin
        self.cov = tf.gather(tf.gather(self.cov, rows, axis=0), cols, axis=1)
        return self.cov

    def K_diag(self, X, presliced=False):
        return tf.linalg.diag_part(self.K(X, presliced=presliced))
141 |
--------------------------------------------------------------------------------
/graph_kernels/time_kernels.py:
--------------------------------------------------------------------------------
1 | import gpflow
2 | from gpflow import Parameter
3 | import tensorflow as tf
4 |
5 | import scipy
6 | from scipy import sparse
7 |
8 | import networkx as nx
9 |
10 | from . import utils
11 | from . import kernels
12 |
13 |
def get_adj_matrix(graph):
    """Return the sparse adjacency matrix of ``graph``.

    * Weighted graphs (every edge has ``"weight"``): CSR matrix of the weights,
      rows/columns in ``graph.nodes()`` order.
    * Unweighted graphs whose first node is a Python ``int``: rows/columns
      ordered 0 .. N - 1. This presumably requires the labels to be exactly
      0 .. N - 1 — confirm for relabelled graphs. (NumPy integer labels fail
      the ``isinstance(..., int)`` check and take the next branch.)
    * Otherwise: rows/columns in ``graph.nodes()`` order.

    Raises:
        ValueError: if the graph has no nodes.
    """
    # Peek at the first node without materialising the whole node list.
    try:
        element = next(iter(graph.nodes()))
    except StopIteration:
        raise ValueError("get_adj_matrix: the graph has no nodes") from None

    if nx.is_weighted(graph):
        return sparse.csr_matrix(nx.linalg.attrmatrix.attr_matrix(graph, "weight", rc_order=graph.nodes()))
    if isinstance(element, int):
        return nx.adjacency_matrix(graph, nodelist=range(len(graph.nodes())))
    return nx.adjacency_matrix(graph, nodelist=graph.nodes())
23 |
24 |
class TimeDistributed1dExponentialKernel(gpflow.kernels.base.Kernel):
    """Product of an Exponential kernel over time and one over the remaining columns.

    Input rows are ``(features..., time)``: the last column goes to the time
    kernel, all preceding columns to the second Exponential kernel. Both
    component variances are frozen so the overall scale is carried by
    ``self.variance`` alone.
    """

    def __init__(self, graph, variance=1.0):
        super().__init__()
        self.variance = Parameter(variance, transform=gpflow.utilities.positive(), name="variance")
        self.graph = graph
        self.time_kernel = gpflow.kernels.Exponential()
        self.graph_kernel = gpflow.kernels.Exponential()
        gpflow.set_trainable(self.graph_kernel.variance, False)
        gpflow.set_trainable(self.time_kernel.variance, False)

    def K(self, X, Y=None, presliced=False):
        # Reshape with -1 rather than X.shape[0]: the static batch size can be
        # None when this is traced inside tf.function.
        t = tf.reshape(tf.cast(X[:, -1], tf.float64), [-1, 1])
        X = tf.cast(X[:, :-1], tf.float64)

        if Y is not None:
            t2 = tf.reshape(tf.cast(Y[:, -1], tf.float64), [-1, 1])
            X2 = tf.cast(Y[:, :-1], tf.float64)
        else:
            t2 = t
            X2 = X

        cov = self.variance * self.time_kernel(t, t2) * self.graph_kernel(X, X2)
        return cov

    def K_diag(self, X, presliced=False):
        return tf.linalg.diag_part(self.K(X, presliced=presliced))
51 |
52 |
class TimeDistributedGraphKernel(gpflow.kernels.base.Kernel):
    """Separable kernel ``variance * k_time(t, t') * k_graph(node, node')``.

    Input rows are ``(node_id, time)``: the last column is time, the rest is
    passed to ``graph_kernel``. The time kernel defaults to an RBF kernel
    (``time_kernel_class()`` otherwise). Both component variances are frozen
    so the overall scale is carried by ``self.variance`` alone.
    """

    def __init__(self, graph, graph_kernel, variance=1.0, time_kernel_class=None):
        super().__init__()
        self.variance = Parameter(variance, transform=gpflow.utilities.positive(), name="variance")
        self.graph = graph
        self.time_kernel = time_kernel_class() if time_kernel_class is not None else gpflow.kernels.RBF()
        self.graph_kernel = graph_kernel
        gpflow.set_trainable(self.graph_kernel.variance, False)
        gpflow.set_trainable(self.time_kernel.variance, False)

    # Input: (node_id, time)
    def K(self, X, Y=None, presliced=False):
        # Reshape with -1 rather than X.shape[0]: the static batch size can be
        # None when this is traced inside tf.function.
        t = tf.reshape(tf.cast(X[:, -1], tf.float64), [-1, 1])
        X = tf.cast(X[:, :-1], tf.float64)

        if Y is not None:
            t2 = tf.reshape(tf.cast(Y[:, -1], tf.float64), [-1, 1])
            X2 = tf.cast(Y[:, :-1], tf.float64)
        else:
            t2 = t
            X2 = X

        cov = self.variance * self.time_kernel(t, t2) * self.graph_kernel(X, X2)
        return cov

    def K_diag(self, X, presliced=False):
        return tf.linalg.diag_part(self.K(X, presliced=presliced))
80 |
81 |
class TimeDistributedLaplacianKernel(TimeDistributedGraphKernel):
    """Time-distributed kernel whose spatial factor is ``kernels.LaplacianKernel``."""

    def __init__(self, graph, variance=1.0, time_kernel_class=None, normalized_laplacian=True):
        adjacency = get_adj_matrix(graph)
        spatial_kernel = kernels.LaplacianKernel(adjacency, normalized_laplacian=normalized_laplacian)
        super().__init__(graph, spatial_kernel, variance, time_kernel_class)
87 |
88 |
class TimeDistributedMaternKernel(TimeDistributedGraphKernel):
    """Time-distributed kernel whose spatial factor is ``kernels.MaternKernel(nu, kappa)``."""

    def __init__(self, graph, nu, kappa, variance=1.0, normalized_laplacian=True, time_kernel_class=None):
        adjacency = get_adj_matrix(graph)
        spatial_kernel = kernels.MaternKernel(
            adjacency, nu, kappa, normalized_laplacian=normalized_laplacian)
        super().__init__(graph, spatial_kernel, variance, time_kernel_class)
96 |
97 |
class TimeDistributedDiffusionKernel(TimeDistributedGraphKernel):
    """Time-distributed kernel whose spatial factor is ``kernels.DiffusionKernel``."""

    def __init__(self, graph, variance=1.0, normalized_laplacian=True, time_kernel_class=None):
        adjacency = get_adj_matrix(graph)
        spatial_kernel = kernels.DiffusionKernel(adjacency, normalized_laplacian=normalized_laplacian)
        super().__init__(graph, spatial_kernel, variance, time_kernel_class)
104 |
105 |
class TimeDistributedRandomWalkKernel(TimeDistributedGraphKernel):
    """Time-distributed kernel whose spatial factor is ``kernels.RandomWalkKernel``."""

    def __init__(self, graph, variance=1.0, time_kernel_class=None):
        adjacency = get_adj_matrix(graph)
        spatial_kernel = kernels.RandomWalkKernel(adjacency)
        super().__init__(graph, spatial_kernel, variance, time_kernel_class)
111 |
112 |
def get_inds(t_indices, t2_indices, X, X2):
    """Build gather indices ``(t1, t2, x1, x2)`` for every pair of input rows.

    ``t_indices``/``t2_indices`` are [N, 1] int64 positions into the unique
    times; ``X``/``X2`` hold node ids, which are cast to int64. The result
    has shape [N1, N2, 4] and is meant for ``tf.gather_nd`` on a
    [T1, T2, n, n] covariance tensor.
    """
    rows = tf.concat([t_indices, tf.cast(X, dtype=tf.int64)], axis=-1)
    cols = tf.concat([t2_indices, tf.cast(X2, dtype=tf.int64)], axis=-1)
    # cartesian_product yields (t1, x1, t2, x2); reorder to (t1, t2, x1, x2).
    pairs = utils.cartesian_product(rows, cols)
    return tf.gather(pairs, [0, 2, 1, 3], axis=2)
120 |
121 |
def get_exponents_tf(vals, Gamma):
    """Return ``expm(-v * Gamma)`` for every entry ``v`` of ``vals`` (no deduplication)."""
    scaled = tf.tensordot(vals, -Gamma, axes=0)
    return tf.linalg.expm(scaled)
124 |
125 |
def get_exponents(vals, Gamma):
    """Return ``expm(-v * Gamma)`` for every entry ``v`` of ``vals``.

    The matrix exponential is evaluated once per distinct value and then
    gathered back, so the result has shape ``vals.shape + Gamma.shape``.
    """
    flat_vals = tf.reshape(vals, [-1])
    unique_vals = tf.sort(tf.unique(flat_vals)[0])
    # Position of each entry of `vals` inside the sorted unique values.
    int_dist = tf.reshape(
        tf.where(tf.equal(flat_vals[:, tf.newaxis], unique_vals[tf.newaxis, :]))[:, 1],
        vals.shape)
    # tf.cast (unlike tf.constant) accepts symbolic tensors inside tf.function
    # and converts other float dtypes instead of raising.
    unique_vals = tf.cast(unique_vals, dtype=tf.float64)
    return tf.gather(tf.linalg.expm(tf.tensordot(unique_vals, -Gamma, axes=0)), int_dist)
133 |
134 |
def get_exponents_scalar_tf(vals, lambdas):
    """Return ``exp(-lambda_k * v)`` for every ``lambda_k`` and entry ``v``; shape ``lambdas.shape + vals.shape``."""
    outer = tf.tensordot(-lambdas, vals, axes=0)
    return tf.math.exp(outer)
137 |
138 |
def get_exponents_scalar(vals, lambdas):
    """Return ``exp(-lambda_k * v)`` for every eigenvalue ``lambda_k`` and entry ``v`` of ``vals``.

    Exponentials are evaluated once per distinct value. For a 1-D
    ``lambdas`` the result has shape ``[len(lambdas), *vals.shape]``.
    """
    flat_vals = tf.reshape(vals, [-1])
    unique_vals = tf.sort(tf.unique(flat_vals)[0])
    # Position of each entry of `vals` inside the sorted unique values.
    int_dist = tf.reshape(
        tf.where(tf.equal(flat_vals[:, tf.newaxis], unique_vals[tf.newaxis, :]))[:, 1],
        vals.shape)
    # tf.cast (unlike tf.constant) accepts symbolic tensors inside tf.function.
    unique_vals = tf.cast(unique_vals, dtype=tf.float64)
    return tf.gather(tf.math.exp(tf.tensordot(-lambdas, unique_vals, axes=0)), int_dist, axis=1)
146 |
147 |
# calculating exp(-lambda (t + s))
def get_sums_exps(unique_t, unique_t2, lambdas):
    """Return ``exp(-lambda_k * (t_i + s_j))`` with shape ``[K, T1, T2]``.

    Only the eigenvalues ``lambdas`` are needed because the noise covariance
    Sigma is taken to be diagonal.
    """
    # reshape(-1) instead of squeeze: squeezing a single time point gives a
    # scalar, which the [:, None] broadcasting below cannot index.
    unique_t = tf.reshape(unique_t, [-1])
    unique_t2 = tf.reshape(unique_t2, [-1])
    time_pairwise_sums = unique_t[:, None] + unique_t2[None, :]
    return tf.math.exp(tf.tensordot(-lambdas, time_pairwise_sums, axes=0))
155 |
156 |
# calculating a solution for stochastic heat equation
# for diagonal variance
def get_covariance_solution(dists_exps, sums_exps, variance, u, gamma_s):
    """Stochastic heat equation covariance for a diagonal noise scale.

    Per eigenmode k, the coefficient ``variance_k**2 / (2 * gamma_s_k)`` scales
    ``dists_exps[k] - sums_exps[k]``. The result is placed on a diagonal and
    rotated back with ``u @ G @ u^T``. Near-zero denominators (|x| < 1e-7)
    are replaced by 1.
    """
    # Only the diagonal of diag(variance)**2 / (gamma_i + gamma_j) is ever used,
    # so compute it directly instead of building two n x n matrices.
    coeff = tf.math.pow(variance, 2) / utils.replace_small_values(gamma_s + gamma_s, 1e-7)
    G = coeff[:, tf.newaxis, tf.newaxis] * (dists_exps - sums_exps)
    G = tf.linalg.diag(tf.transpose(G, [1, 2, 0]))
    return u @ G @ tf.transpose(u)
166 |
167 |
def get_covariance_solution_fixed(t, s, u, variance, lambdas):
    """Covariance of the graph stochastic heat equation for a diagonal noise scale.

    Works in the eigenbasis ``u`` with eigenvalues ``lambdas``. For modes i, j
    and times t_a, s_b the entry is

        M_ij * exp(-(lambda_i t_a + lambda_j s_b)) * expm1(r_ij * min(t_a, s_b)) / r_ij

    with ``r_ij = lambda_i + lambda_j`` and
    ``M = u^T diag(variance) diag(variance)^T u``. The entries are rotated back
    with ``u @ G @ u^T``, giving shape [T1, T2, n, n]. When ``r_ij`` is
    (numerically) zero, e.g. for the zero eigenvalue of a graph Laplacian,
    the ratio is replaced by its limit ``min(t_a, s_b)`` instead of 0/0 = NaN.
    """
    t = tf.reshape(t, [-1])
    s = tf.reshape(s, [-1])
    sigma = tf.linalg.diag(variance)
    # Noise covariance expressed in the eigenbasis u.
    mult = tf.transpose(u) @ sigma @ tf.transpose(sigma) @ u
    # pair_sums[i, j] = lambda_i + lambda_j
    pair_sums = lambdas[None, :] + lambdas[:, None]

    lt = lambdas[:, None] @ t[None, :]
    ls = lambdas[:, None] @ s[None, :]
    # pairwise_sums[i, j, a, b] = lambda_i * t_a + lambda_j * s_b
    pairwise_sums = lt[:, :, None, None] + ls[None, None, :, :]
    pairwise_sums = tf.transpose(pairwise_sums, [0, 2, 1, 3])

    mins = tf.math.minimum(t[:, None], s[None, :])[None, None, :, :]
    rates = pair_sums[:, :, None, None]
    # exp(r*m - p) - exp(-p) == exp(-p) * expm1(r*m); expm1 avoids cancellation,
    # and expm1(r*m) / r -> m as r -> 0.
    is_small = tf.abs(rates) < 1e-7
    safe_rates = tf.where(is_small, tf.ones_like(rates), rates)
    growth = tf.math.expm1(rates * mins) / safe_rates
    growth = tf.where(is_small, mins, growth)

    G = mult[:, :, None, None] * tf.math.exp(-pairwise_sums) * growth
    G = tf.transpose(G, [2, 3, 0, 1])
    return u @ G @ tf.transpose(u)
186 |
187 |
class StochasticHeatEquation(gpflow.kernels.base.Kernel):
    """Covariance kernel of a stochastic heat equation on a graph.

    Input rows are ``[node_id, time]``. The last column is the time; the
    remaining column(s) are cast to int64 and used as node indices into the
    per-node covariance (see ``get_inds``).

    The spatial operator Gamma is ``c * L``, with ``L`` the (optionally
    normalized) graph Laplacian. When ``use_pseudodifferential`` is set it is
    instead ``c * U diag((2 * nu / kappa**2 + s) ** (nu / 2)) U^T``, where
    ``L = U diag(s) V^T`` is the SVD of ``L``.

    Args:
        graph: networkx graph. When ``nx.is_weighted(graph)`` is true the
            matrix of "weight" attributes is used, otherwise
            ``nx.adjacency_matrix(graph)``.
        variance: noise variance (lower bound 1e-4 via the positive transform).
            A scalar selects the matrix-exponential solution; a non-scalar
            selects ``get_covariance_solution_fixed`` (diagonal noise).
        c: diffusion coefficient (lower bound 1e-4, parameter name "diffusion").
        normalized_laplacian: use the normalized Laplacian.
        use_pseudodifferential: replace ``L`` by the pseudo-differential
            operator above; requires ``nu`` and ``kappa``.
        nu, kappa: plain (non-trainable) attributes of the operator.
    """

    def __init__(self, graph, variance=1.0, c=1, normalized_laplacian=True,
                 use_pseudodifferential=False, nu=None, kappa=None):
        super().__init__()
        self.variance = Parameter(variance, transform=gpflow.utilities.positive(1e-4), name="variance")
        self.c = Parameter(c, transform=gpflow.utilities.positive(1e-4), name="diffusion")
        self.graph = graph
        if nx.is_weighted(graph):
            # Adjacency built from the "weight" edge attribute, rows/cols in graph.nodes() order.
            self.laplacian = utils.get_laplacian(
                sparse.csr_matrix(nx.linalg.attrmatrix.attr_matrix(graph, "weight", rc_order=graph.nodes())),
                normalized_laplacian)
        else:
            self.laplacian = utils.get_laplacian(nx.adjacency_matrix(graph), normalized_laplacian)

        self.use_pseudodifferential = use_pseudodifferential
        if use_pseudodifferential:
            self.nu = nu
            self.kappa = kappa
        else:
            self.nu = None
            self.kappa = None

        # laplacian = self.u @ tf.linalg.diag(self.laplacian_s) @ tf.transpose(self.v)
        # The SVD is computed once; K only rescales the singular values.
        self.laplacian_s, self.u, self.v = tf.linalg.svd(self.laplacian)

    def get_scaled_differential_s(self):
        """Return the eigenvalues of the scaled operator Gamma (a vector)."""
        if self.use_pseudodifferential:
            return self.c * ((2 * self.nu) / (self.kappa ** 2) + self.laplacian_s) ** (self.nu / 2)
        else:
            return self.c * self.laplacian_s

    def get_scaled_differential(self):
        """Return the scaled operator Gamma as a dense matrix."""
        if self.use_pseudodifferential:
            return self.u @ tf.linalg.diag(self.get_scaled_differential_s()) @ tf.transpose(self.u)
        else:
            return self.c * self.laplacian

    # Input: (node_id, time)
    def K(self, X, Y=None, presliced=False):
        """Return the [N1, N2] covariance between the rows of X and Y (Y defaults to X).

        ``presliced`` is accepted for interface compatibility and ignored.
        """
        t = tf.reshape(tf.cast(X[:, -1], tf.float64), [X.shape[0]])
        X = tf.cast(X[:, :-1], tf.float64)
        if Y is not None:
            t2 = tf.reshape(tf.cast(Y[:, -1], tf.float64), [Y.shape[0]])
            X2 = tf.cast(Y[:, :-1], tf.float64)
        else:
            t2 = t
            X2 = X

        # The covariance is built once per distinct pair of times and gathered afterwards.
        unique_t = tf.sort(tf.unique(t)[0])[:, tf.newaxis]
        unique_t2 = tf.sort(tf.unique(t2)[0])[:, tf.newaxis]

        # Side effect: the [T1, T2] |t - s| and t + s grids are stored on self.
        self.time_pairwise_distances = tf.abs(unique_t - tf.transpose(unique_t2))
        self.time_pairwise_sums = (unique_t + tf.transpose(unique_t2))
        Gamma = self.get_scaled_differential()

        gamma_s = self.get_scaled_differential_s()
        if len(self.variance.shape) > 0:
            # Non-scalar variance: diagonal-noise solution in the eigenbasis u.
            cov = get_covariance_solution_fixed(
                tf.squeeze(unique_t), tf.squeeze(unique_t2), self.u, self.variance, gamma_s)
        else:
            # variance * (expm(-|t - s| Gamma) - expm(-(t + s) Gamma)) @ pinv(Gamma),
            # shape [T1, T2, n, n].
            left_part = get_exponents(self.time_pairwise_distances, Gamma)
            right_part = get_exponents(self.time_pairwise_sums, Gamma)
            cov = self.variance * (left_part - right_part) @ tf.linalg.pinv(Gamma)
        # For each input row, the position of its time within the sorted unique times.
        t_indices = tf.where(tf.transpose(tf.equal(t, unique_t)))[:, 1]
        t2_indices = tf.where(tf.transpose(tf.equal(t2, unique_t2)))[:, 1]

        t_indices = tf.expand_dims(t_indices, 1)
        t2_indices = tf.expand_dims(t2_indices, 1)

        # Pick cov[t_idx, t2_idx, node, node2] for every pair of rows -> [N1, N2].
        inds = get_inds(t_indices, t2_indices, X, X2)
        cov = tf.gather_nd(cov, inds)
        return cov

    def K_diag(self, X, presliced=False):
        """Diagonal of ``K(X)``; builds the full N x N matrix first."""
        return tf.linalg.diag_part(self.K(X, presliced=presliced))
263 |
264 |
def get_cosines(vals, Gamma):
    """Return the matrix cosine ``cos(v * Gamma)`` for every entry ``v`` of ``vals``.

    Each distinct value is evaluated once; the result has shape
    ``vals.shape + Gamma.shape``.
    """
    flat = tf.reshape(vals, [-1])
    distinct = tf.sort(tf.unique(flat)[0])
    matches = tf.equal(flat[:, tf.newaxis], distinct[tf.newaxis, :])
    positions = tf.reshape(tf.where(matches)[:, 1], vals.shape)
    per_value = utils.tf_cosm(tf.tensordot(distinct, Gamma, axes=0))
    return tf.gather(per_value, positions)
272 |
273 |
def get_sines(vals, Gamma):
    """Return the matrix sine ``sin(v * Gamma)`` for every entry ``v`` of ``vals``.

    Each distinct value is evaluated once; the result has shape
    ``vals.shape + Gamma.shape``.
    """
    flat = tf.reshape(vals, [-1])
    distinct = tf.sort(tf.unique(flat)[0])
    matches = tf.equal(flat[:, tf.newaxis], distinct[tf.newaxis, :])
    positions = tf.reshape(tf.where(matches)[:, 1], vals.shape)
    per_value = utils.tf_sinm(tf.tensordot(distinct, Gamma, axes=0))
    return tf.gather(per_value, positions)
281 |
282 |
def get_cosines_tf(vals, Gamma):
    """Return ``cos(v * Gamma)`` for every entry ``v`` of ``vals`` (no deduplication)."""
    scaled = tf.tensordot(vals, Gamma, axes=0)
    return utils.tf_cosm(scaled)
285 |
286 |
def get_sines_tf(vals, Gamma):
    """Return ``sin(v * Gamma)`` for every entry ``v`` of ``vals`` (no deduplication)."""
    scaled = tf.tensordot(vals, Gamma, axes=0)
    return utils.tf_sinm(scaled)
289 |
290 |
class StochasticWaveEquationKernel(gpflow.kernels.base.Kernel):
    """Covariance kernel of a stochastic wave equation on a graph.

    Input rows are ``[node_id, time]`` (the last column is time). The spatial
    operator is the pseudo-differential ``(2 * nu / kappa**2 + L) ** (nu / 2)``
    of the graph Laplacian ``L``, built from its SVD; ``c`` is the
    propagation speed. Only this pseudo-differential form and a scalar
    ``variance`` are supported. ``K`` raises ``NotImplementedError`` when the
    kernel was built without ``use_pseudodifferential=True`` and both
    ``nu`` and ``kappa``, or when ``variance`` is not a scalar.
    """

    def __init__(self, graph, variance=1.0, c=1., normalized_laplacian=True, use_pseudodifferential=False,
                 nu=None, kappa=None):
        super().__init__()
        self.variance = Parameter(variance, transform=gpflow.utilities.positive(), name="variance")
        self.c = Parameter(c, transform=gpflow.utilities.positive(1e-2), name="propagation speed")
        self.graph = graph
        self.laplacian = utils.get_laplacian(nx.adjacency_matrix(graph), normalized_laplacian)
        if use_pseudodifferential:
            self.nu = nu
            self.kappa = kappa
            self.laplacian_s, self.u, self.v = tf.linalg.svd(self.laplacian)
        else:
            self.nu = None
            self.kappa = None
        self.id_l = tf.eye(self.laplacian.shape[0], dtype=tf.float64)

    # Input: (node_id, time)
    def K(self, X, Y=None, presliced=False):
        if self.nu is None or self.kappa is None:
            # Without nu/kappa the operator below is undefined (and the SVD was
            # never computed); fail with a clear message instead of a TypeError.
            raise NotImplementedError(
                "StochasticWaveEquationKernel requires use_pseudodifferential=True "
                "with nu and kappa set")

        # Pseudo-differential operator, its square root and pseudo-inverse, all
        # rebuilt from the stored SVD (self.laplacian is overwritten on purpose).
        s = ((2 * self.nu) / (self.kappa ** 2) + self.laplacian_s) ** (self.nu / 2)
        self.laplacian = self.u @ tf.linalg.diag(s) @ tf.transpose(self.u)
        s = ((2 * self.nu) / (self.kappa ** 2) + self.laplacian_s) ** (self.nu / 4)
        self.sqrt_lapl = self.u @ tf.linalg.diag(s) @ tf.transpose(self.u)
        self.laplacian_inv = tf.linalg.pinv(self.laplacian)

        t = tf.reshape(tf.cast(X[:, -1], tf.float64), [X.shape[0]])
        X = tf.cast(X[:, :-1], tf.float64)
        if Y is not None:
            t2 = tf.reshape(tf.cast(Y[:, -1], tf.float64), [Y.shape[0]])
            X2 = tf.cast(Y[:, :-1], tf.float64)
        else:
            t2 = t
            X2 = X
        unique_t = tf.sort(tf.unique(t)[0])[:, tf.newaxis]
        unique_t2 = tf.sort(tf.unique(t2)[0])[:, tf.newaxis]
        time_pairwise_distances = tf.abs(unique_t - tf.transpose(unique_t2))

        theta = self.c * self.sqrt_lapl
        mins = tf.math.minimum(unique_t, tf.transpose(unique_t2))
        maxs = tf.math.maximum(unique_t, tf.transpose(unique_t2))
        # Equivalent to pinv(c**2 * laplacian) without another pseudo-inverse.
        gamma_inv = (1 / self.c**2) * self.laplacian_inv
        if len(self.variance.shape) > 0:
            raise NotImplementedError("Not implemented for matrix variance")
        gamma_inv = self.variance * gamma_inv
        cov = gamma_inv @ get_cosines(time_pairwise_distances, theta)
        # [T1, T2, n, n]: min(t, s) * cov - 0.5 * Gamma^-1 cos(max θ) sin(min θ) θ^-1
        cov = tf.tensordot(mins, self.id_l, axes=0) @ cov - 0.5 *\
            gamma_inv @ get_cosines(maxs, theta) @ get_sines(mins, theta) @ tf.linalg.inv(theta)

        # For each input row, the position of its time within the sorted unique times.
        t_indices = tf.where(tf.transpose(tf.equal(t, unique_t)))[:, 1]
        t2_indices = tf.where(tf.transpose(tf.equal(t2, unique_t2)))[:, 1]

        t_indices = tf.expand_dims(t_indices, 1)
        t2_indices = tf.expand_dims(t2_indices, 1)

        inds = get_inds(t_indices, t2_indices, X, X2)
        return tf.gather_nd(cov, inds)

    def K_diag(self, X, presliced=False):
        # Builds the full N x N matrix and keeps its diagonal.
        return tf.linalg.diag_part(self.K(X, presliced=presliced))
354 |
--------------------------------------------------------------------------------
/graph_kernels/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import pickle
5 | import networkx as nx
6 | import tensorflow as tf
7 | import numpy as np
8 |
9 | from matplotlib import pyplot as plt
10 | import seaborn as sns
11 |
12 | import sklearn.metrics
13 | from sklearn.utils import shuffle
14 | from sklearn.decomposition import PCA
15 | from sklearn.manifold import TSNE
16 |
17 | import gpflow
18 | from gpflow import Parameter
19 |
20 | from . import data_utils
21 |
22 |
def sparse_mat_to_sparse_tensor(sparse_mat):
    """
    Converts a scipy sparse matrix (e.g. csr_matrix) to a tensorflow SparseTensor.

    Values are read from the COO view so they always line up with the COO
    indices (``sparse_mat.data`` is not aligned with them for every sparse
    format). Duplicate entries are summed, and the tensor is put in the
    canonical row-major order that ``tf.sparse.to_dense`` validates.
    """
    # copy=True: sum_duplicates works in place and must not touch the caller's matrix.
    coo = sparse_mat.tocoo(copy=True)
    coo.sum_duplicates()
    indices = np.stack([coo.row, coo.col], axis=-1)
    tensor = tf.sparse.SparseTensor(indices, coo.data, sparse_mat.shape)
    return tf.sparse.reorder(tensor)
31 |
32 |
def normalize_laplacian(laplacian, d):
    """Return ``D^{-1/2} L D^{-1/2}`` for a Laplacian ``L`` and diagonal degree matrix ``D``.

    Nodes with zero degree get a zero scaling factor instead of a division by
    zero. The result is a float64 tensor.
    """
    laplacian = tf.cast(laplacian, tf.float64)
    # Stay in float64 end to end; building the factors from a list of Python
    # floats would route them through a float32 tensor.
    degrees = tf.cast(tf.linalg.diag_part(d), tf.float64)
    # divide_no_nan gives 0 where sqrt(degree) == 0 (the old "1/0 -> 0" rule).
    inv_sqrt_degrees = tf.math.divide_no_nan(tf.ones_like(degrees), tf.sqrt(degrees))
    # Scaling rows and columns by broadcasting replaces two dense O(n^3) matmuls
    # with diagonal matrices.
    return inv_sqrt_degrees[:, tf.newaxis] * laplacian * inv_sqrt_degrees[tf.newaxis, :]
40 |
41 |
def get_non_normalized_laplacian(sparse_adj_mat):
    """Return ``(L, D)``: the combinatorial Laplacian ``L = D - A`` and degree matrix ``D``.

    ``sparse_adj_mat`` is a scipy sparse adjacency matrix ``A``; degrees are its
    row sums. Both results are dense float64 tensors.
    """
    sparse_adj_tensor = tf.cast(sparse_mat_to_sparse_tensor(sparse_adj_mat), tf.float64)
    # Densify once and reuse it for both the degrees and the Laplacian.
    adj_dense = tf.sparse.to_dense(sparse_adj_tensor)
    d_dense = tf.linalg.diag(tf.math.reduce_sum(adj_dense, axis=1))
    laplacian = tf.math.subtract(d_dense, adj_dense)
    return laplacian, d_dense
55 |
56 |
def get_normalized_laplacian(sparse_adj_mat):
    """Return the normalized Laplacian ``D^{-1/2} (D - A) D^{-1/2}`` of a scipy adjacency matrix."""
    laplacian, degree_matrix = get_non_normalized_laplacian(sparse_adj_mat)
    normalized = normalize_laplacian(laplacian, degree_matrix)
    return normalized
60 |
61 |
def get_normalized_laplacian_from_graph(graph):
    """Normalized Laplacian of a networkx graph.

    Rows/columns follow the node order 0..N-1 (this assumes nodes are labelled
    with those integers).
    """
    node_order = range(len(graph.nodes()))
    adjacency = nx.adjacency_matrix(graph, nodelist=node_order)
    return get_normalized_laplacian(adjacency)
66 |
67 |
def get_non_normalized_laplacian_from_graph(graph):
    """Combinatorial Laplacian ``D - A`` of a networkx graph.

    Rows/columns follow the node order 0..N-1 (this assumes nodes are labelled
    with those integers).
    """
    node_order = range(len(graph.nodes()))
    adjacency = nx.adjacency_matrix(graph, nodelist=node_order)
    laplacian, _degree_matrix = get_non_normalized_laplacian(adjacency)
    return laplacian
72 |
73 |
def get_laplacian(sparse_adj_mat, normalized_laplacian):
    """Return the normalized or combinatorial Laplacian of a scipy adjacency matrix."""
    if not normalized_laplacian:
        laplacian, _degree_matrix = get_non_normalized_laplacian(sparse_adj_mat)
        return laplacian
    return get_normalized_laplacian(sparse_adj_mat)
79 |
80 |
def get_dataset_ids_from_graph(G, tr_ratio, random_seed=42):
    """Shuffle the node ids 0..N-1 and split them into train/test sets.

    Returns ``(A, idx_train, idx_test)``, where ``A`` is the graph's scipy
    sparse adjacency matrix and the index arrays are column vectors of shape
    (k, 1). The first ``int(N * tr_ratio)`` shuffled ids go to training.
    """
    num_nodes = len(G.nodes())
    shuffled = shuffle(np.arange(num_nodes), random_state=random_seed)
    split_at = int(num_nodes * tr_ratio)
    adjacency = nx.to_scipy_sparse_matrix(G)
    return adjacency, shuffled[:split_at, np.newaxis], shuffled[split_at:, np.newaxis]
89 |
90 |
def evaluate_mse(X_val, y_val, gprocess):
    """Mean squared error between ``y_val`` and the predictive mean of ``gprocess`` on ``X_val``."""
    predicted_mean, _predicted_var = gprocess.predict_y(X_val)
    # MSE is symmetric in its arguments; use sklearn's (y_true, y_pred) order.
    return sklearn.metrics.mean_squared_error(y_val, predicted_mean)
94 |
95 |
def evaluate_mape_predictions(pred_y, y_val, transformer=None):
    """MAPE of ``pred_y`` against ``y_val``, first undoing ``transformer`` when one is given."""
    if transformer is not None:
        pred_y, y_val = transformer.inverse_transform(pred_y), transformer.inverse_transform(y_val)
    score = sklearn.metrics.mean_absolute_percentage_error(y_val, pred_y)
    return score
101 |
102 |
def evaluate_mape(X_val, y_val, gprocess, transformer=None):
    """MAPE of the predictive mean of ``gprocess`` on ``X_val`` against ``y_val``."""
    predicted_mean, _predicted_var = gprocess.predict_y(X_val)
    return evaluate_mape_predictions(predicted_mean, y_val, transformer)
106 |
107 |
def evaluate_mae_predictions(pred_y, y_val, transformer=None):
    """MAE of ``pred_y`` against ``y_val``, first undoing ``transformer`` when one is given."""
    if transformer is not None:
        pred_y, y_val = transformer.inverse_transform(pred_y), transformer.inverse_transform(y_val)
    score = sklearn.metrics.mean_absolute_error(y_val, pred_y)
    return score
113 |
114 |
def evaluate_mae(X_val, y_val, gprocess, transformer=None):
    """MAE of the predictive mean of ``gprocess`` on ``X_val`` against ``y_val``."""
    predicted_mean, _predicted_var = gprocess.predict_y(X_val)
    return evaluate_mae_predictions(predicted_mean, y_val, transformer)
118 |
119 |
def smape(y_pred, y_true):
    """Symmetric mean absolute percentage error, in percent.

    Computes ``100 / n * sum(2 * |y_true - y_pred| / (|y_pred| + |y_true|))``
    with ``n = len(y_pred)``. Lists are accepted as well as arrays. Terms
    where both values are 0 contribute 0 instead of producing NaN.
    """
    y_pred = np.asarray(y_pred, dtype=np.float64)
    y_true = np.asarray(y_true, dtype=np.float64)
    numerator = 2 * np.abs(y_true - y_pred)
    denominator = np.abs(y_pred) + np.abs(y_true)
    ratios = np.divide(numerator, denominator,
                       out=np.zeros_like(numerator), where=denominator != 0)
    return 100 / len(y_pred) * np.sum(ratios)
122 |
123 |
def plot(m, X_train, signal, xmin=0.0, xmax=30, num_points=100,
         title="Adjacency matrix covariance function"):
    """Plot a 1-d GP fit: training points, predictive mean and a +/-2 std band.

    Args:
        m: model whose ``predict_y`` returns ``(mean, var)``; both are assumed
            to have shape [num_points, 1].
        X_train: training inputs; column 0 holds integer positions into ``signal``.
        signal: observed values, indexed by those positions.
        xmin, xmax, num_points: evaluation grid (the defaults are the previous
            hard-coded grid).
        title: figure title.
    """
    xx = np.linspace(xmin, xmax, num_points)[:, None]
    mean, var = m.predict_y(xx)
    # Floor the variance at 1e-3 so the confidence band is always visible.
    var = np.maximum(np.asarray(var, dtype=np.float64).reshape(-1), 1e-3)[:, np.newaxis]
    std = np.sqrt(var[:, 0])
    plt.figure(figsize=(12, 6))
    plt.plot(X_train, signal[[int(el) for el in X_train[:, 0]]], 'kx', mew=2)
    plt.plot(xx, mean, 'b', lw=2)
    plt.fill_between(xx[:, 0], mean[:, 0] - 2 * std, mean[:, 0] + 2 * std, color='blue', alpha=0.2)
    plt.xlim(xmin, xmax)
    plt.title(title)
135 |
136 |
def visualize_gprocess(gprocess, X_train, X_test, G, signal, layout=None):
    """Colour the nodes of ``G`` by the predictive mean, then draw the 1-d ``plot``.

    Each row of ``X_train``/``X_test`` is expected to hold a single node id;
    nodes that appear in neither set keep the value 0.
    """
    inputs = tf.concat((X_train, X_test), axis=0)
    mean, _variance = gprocess.predict_y(inputs)

    node_values = [0] * len(G.nodes())
    for node_row, value in zip(inputs, mean):
        node_values[int(node_row)] = float(value)

    data_utils.plot_nodes_with_colors(G, node_values, layout=layout)
    plot(gprocess, X_train, signal)
146 |
147 |
def training_step(X_train, y_train, optimizer, gprocess, natgrad=None):
    """Run one optimizer step (plus an optional natural-gradient step) and return -ELBO."""
    data = (X_train, y_train)
    loss_closure = gprocess.training_loss_closure(data, compile=False)
    optimizer.minimize(loss_closure, var_list=gprocess.trainable_variables)
    if natgrad is not None:
        # The natural-gradient optimizer updates the variational pair (q_mu, q_sqrt).
        natgrad.minimize(loss_closure, var_list=[(gprocess.q_mu, gprocess.q_sqrt)])
    return -gprocess.elbo(data)
155 |
156 |
def cartesian_product(a, b):
    """Pair every row of ``a`` with every row of ``b``.

    Returns a tensor of shape ``[len(a), len(b), a_cols + b_cols]`` with
    ``out[i, j] = concat(a[i], b[j])``. The last dimension used to be
    hard-coded to 4 (two 2-column inputs); it is now derived from the inputs.
    """
    n_a, n_b = a.shape[0], b.shape[0]
    # Row i * n_b + j of a_ is a[i]; the same row of b_ is b[j].
    a_ = tf.reshape(tf.tile(a, [1, n_b]), (n_a * n_b, a.shape[1]))
    b_ = tf.tile(b, [n_a, 1])
    return tf.reshape(tf.concat([a_, b_], 1), [n_a, n_b, a.shape[1] + b.shape[1]])
162 |
163 |
def is_pos_semi_def(x, tol=1e-7):
    """Return whether every eigenvalue of ``x`` has real part >= ``-tol``.

    ``np.linalg.eigvals`` may return complex values. Their real part is taken
    explicitly, which gives the same numbers as the former float64 cast
    without emitting a ComplexWarning.
    """
    eigenvalues = np.real(np.linalg.eigvals(x))
    return np.all(eigenvalues >= -tol)
166 |
167 |
def save_model_to_hyperparameters(model, save_path="gprocess_hyperparams.pkl"):
    """Pickle ``gpflow.utilities.parameter_dict(model)`` to ``save_path``.

    The file is closed deterministically; the original left the handle open.
    """
    params = gpflow.utilities.parameter_dict(model)
    with open(save_path, "wb") as handle:
        pickle.dump(params, handle)
170 |
171 |
# Usage example: loaded_result = loaded_model.predict_f_compiled(samples_input)
def load_model(model, path):
    """Assign parameters saved by ``save_model_to_hyperparameters`` to ``model`` and return it.

    Warning: ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(path, "rb") as handle:
        params = pickle.load(handle)
    gpflow.utilities.multiple_assign(model, params)
    return model
177 |
178 |
def set_all_random_seeds(random_seed):
    """Reset TensorFlow/Keras global state, then seed TensorFlow, ``random`` and NumPy."""
    tf.compat.v1.reset_default_graph()
    tf.keras.backend.clear_session()
    for seed_fn in (tf.random.set_seed, random.seed, np.random.seed):
        seed_fn(random_seed)
186 |
187 |
class ConstantArray(gpflow.mean_functions.MeanFunction):
    """Mean function holding one trainable constant per node id.

    For an input row, the mean is ``c[int(X[row, 0])]``; ``c`` starts at zero.
    """

    def __init__(self, shape):
        super().__init__()
        c = tf.zeros(shape)
        self.c = Parameter(c, name="constant array mean")

    def __call__(self, X):
        node_ids = tf.cast(X[:, 0], dtype=tf.int32)
        # Dynamic batch size: X.shape[0] is None when the batch dimension is
        # unknown (e.g. inside tf.function), which tf.reshape cannot use.
        return tf.reshape(tf.gather(self.c, node_ids), (tf.shape(X)[0], 1))
196 |
197 |
def tf_cosm(A):
    """Matrix cosine of ``A`` computed as ``Re(expm(i * A))`` in complex128."""
    exponent = 1j * tf.cast(A, dtype=tf.complex128)
    return tf.math.real(tf.linalg.expm(exponent))
200 |
201 |
def tf_sinm(matrix):
    """Matrix sine ``sin(A)``.

    Complex input uses ``(expm(iA) - expm(-iA)) / (2i)``; real input takes
    the imaginary part of ``expm(iA)``.
    """
    if not matrix.dtype.is_complex:
        i_matrix = tf.complex(tf.zeros_like(matrix), matrix)
        return tf.math.imag(tf.linalg.expm(i_matrix))
    i_matrix = 1j * matrix
    return -0.5j * (tf.linalg.expm(i_matrix) - tf.linalg.expm(-i_matrix))
209 |
210 |
class Callback:
    """Optimisation callback that prints the training objective and test metrics.

    Each call computes test-set MAPE and MAE. The printed "ELBO" is
    ``model.elbo`` on the training data, or the value of ``loss_fn`` when
    one was supplied. One line is printed and ``epoch`` is incremented.
    The call accepts and ignores the ``(step, variables, values)``
    arguments passed by optimizer callbacks.
    """

    def __init__(self, model, Xtrain, Ytrain, Xtest, Ytest, loss_fn=None, transformer=None):
        self.model = model
        self.Xtrain, self.Ytrain = Xtrain, Ytrain
        self.Xtest, self.Ytest = Xtest, Ytest
        self.transformer = transformer
        self.epoch = 0
        self.loss_fn = loss_fn

    def __call__(self, step=None, variables=None, values=None):
        transformer = self.transformer
        mape = evaluate_mape(self.Xtest, self.Ytest, self.model, transformer=transformer)
        mae = evaluate_mae(self.Xtest, self.Ytest, self.model, transformer=transformer)
        if self.loss_fn is not None:
            elbo = self.loss_fn()
        else:
            elbo = self.model.elbo((self.Xtrain, self.Ytrain)).numpy()

        print(f"{self.epoch}:\tELBO: {elbo:.5f}\tMAPE: {mape:.10f}\tMAE: {mae:.10f}")
        self.epoch += 1
232 |
233 |
def replace_small_values(tensor, eps=1e-7):
    """Return ``tensor`` with every entry of magnitude below ``eps`` replaced by 1."""
    is_small = tf.abs(tensor) < eps
    return tf.where(is_small, tf.ones_like(tensor), tensor)
238 |
239 |
def get_hmc_sample(num_samples, samples, hmc_helper, model, test_X, num_f_samples=5):
    """Draw predictive function samples at ``test_X`` for each stored HMC state.

    For every ``i < num_samples``, each variable in ``hmc_helper.current_state``
    is assigned ``samples[k][i]``. Then ``num_f_samples`` draws are taken
    with ``model.predict_f_samples``. All draws are stacked along axis 0.
    The progress index is printed every 10 states. Note that the model is
    left holding the last state's values.

    Args:
        num_num_samples: number of stored states to use.
        samples: one array of per-state values for each variable in ``current_state``.
        hmc_helper: object whose ``current_state`` variables support ``assign``.
        model: model exposing ``predict_f_samples(X, n)``.
        test_X: inputs to sample at.
        num_f_samples: draws per state (previously hard-coded to 5).
    """
    f_samples = []
    for i in range(num_samples):
        if i % 10 == 0:
            print(i)
        # Note that hmc_helper.current_state contains the unconstrained variables
        for var, var_samples in zip(hmc_helper.current_state, samples):
            var.assign(var_samples[i])
        f_samples.append(model.predict_f_samples(test_X, num_f_samples))
    return np.vstack(f_samples)
252 |
--------------------------------------------------------------------------------
/graph_kernels/utils_opt.py:
--------------------------------------------------------------------------------
1 | import gpflow
2 | import tensorflow as tf
3 | import tensorflow_probability as tfp
4 | from tensorflow_probability import distributions as tfd
5 | import numpy as np
6 |
7 | from . import utils
8 |
9 |
# Import-time global side effect: gpflow Parameters created after this module
# is imported default to float64. f64 is a shorthand for casting values to that
# default float dtype (used for the prior hyperparameters below).
gpflow.config.set_default_float(tf.float64)
f64 = gpflow.utilities.to_default_float
12 |
13 |
def optimize_ada_natgrad(gprocess, train_X, train_y, test_X, test_y, n_iter, learning_rate=1e-2,
                         transformer=None, decay_steps=5000, decay_rate=0.5, natgrad_gamma=0.1):
    """Train an SVGP: Adam on the hyperparameters, natural gradients on q(u).

    ``q_mu`` and ``q_sqrt`` are made non-trainable for Adam, because the
    natural-gradient step updates them instead. The Adam learning rate follows
    ``learning_rate * decay_rate ** (step / decay_steps)``. The previously
    hard-coded values are kept as defaults: 5000 steps, rate 0.5, gamma 0.1.

    Returns:
        ``(result, gprocess)``, where ``result[epoch]`` maps "ELBO", "MAPE"
        and "MAE" to that epoch's values. Progress is printed each epoch.
    """
    decayed_lr = tf.keras.optimizers.schedules.ExponentialDecay(
        learning_rate, decay_steps, decay_rate, staircase=False, name=None
    )
    optimizer = tf.optimizers.Adam(decayed_lr)

    gpflow.set_trainable(gprocess.q_mu, False)
    gpflow.set_trainable(gprocess.q_sqrt, False)
    natgrad_opt = gpflow.optimizers.NaturalGradient(gamma=natgrad_gamma)

    result = {}
    for epoch in range(n_iter):
        # training_step returns -ELBO.
        elbo = -utils.training_step(
            train_X, train_y, optimizer, gprocess, natgrad_opt).numpy()

        mape = utils.evaluate_mape(test_X, test_y, gprocess, transformer)
        mae = utils.evaluate_mae(test_X, test_y, gprocess, transformer)
        result[epoch] = {
            "ELBO": elbo,
            "MAPE": mape,
            "MAE": mae,
        }
        print(f"{epoch}:\tELBO: {elbo:.5f}\tMAPE: {mape:.10f}\tMAE: {mae:.10f}")

    return result, gprocess
40 |
41 |
def optimize_lbfgs_b(gprocess, train_X, train_y, test_X, test_y, n_iter, transformer=None, compile=False):
    """Fit ``gprocess`` with gpflow's Scipy optimizer, printing metrics after every step.

    Returns ``({"ELBO", "MAPE", "MAE"}, gprocess)``. The "ELBO" entry holds
    the final value of the training-loss closure.
    """
    data = (train_X, train_y)
    scipy_opt = gpflow.optimizers.Scipy()
    loss_fn = gprocess.training_loss_closure(data, compile=compile)
    progress = utils.Callback(
        gprocess, train_X, train_y, test_X, test_y,
        transformer=transformer)
    scipy_opt.minimize(
        loss_fn,
        variables=gprocess.trainable_variables,
        compile=compile,
        options=dict(disp=True, maxiter=n_iter),
        step_callback=progress,
    )

    final_mape = utils.evaluate_mape(test_X, test_y, gprocess, transformer=transformer)
    final_mae = utils.evaluate_mae(test_X, test_y, gprocess, transformer=transformer)
    final_loss = loss_fn()
    metrics = {"ELBO": final_loss.numpy(), "MAPE": final_mape, "MAE": final_mae}
    return metrics, gprocess
61 |
62 |
def evaluate_kernel_svgp(kernel, train_X, train_y, test_X, test_y, graph, transformer=None,
                         dump_everything=False, dump_directory=None, optimizer_name="Adam", n_iter=2000,
                         mean_function=None, compile=False):
    """Build an SVGP with ``kernel`` and fit it.

    The inducing points are fixed at ``train_X`` and the Gaussian likelihood
    variance starts at 1e-2. The mean defaults to a per-node
    ``ConstantArray``. ``optimizer_name`` selects Adam plus natural
    gradients ("Adam") or gpflow's Scipy optimizer ("LBFGS"); anything else
    raises ValueError. ``dump_everything`` and ``dump_directory`` are
    accepted but unused. Returns ``(result, gprocess)``.
    """
    if mean_function is None:
        mean_function = utils.ConstantArray(len(graph.nodes()))

    gprocess = gpflow.models.SVGP(
        kernel,
        gpflow.likelihoods.Gaussian(),
        inducing_variable=train_X,
        mean_function=mean_function,
        whiten=True,
        q_diag=False,
    )
    gpflow.set_trainable(gprocess.inducing_variable, False)
    gprocess.likelihood.variance.assign(1e-2)

    if optimizer_name == "LBFGS":
        return optimize_lbfgs_b(
            gprocess, train_X, train_y, test_X, test_y, n_iter=n_iter, transformer=transformer,
            compile=compile)
    if optimizer_name != "Adam":
        raise ValueError("Supported optimizers: Adam & LBFGS")
    return optimize_ada_natgrad(
        gprocess, train_X, train_y,
        test_X, test_y, n_iter=n_iter, transformer=transformer)
86 |
87 |
def initialize_hmc_helpers(model):
    """Create an HMC sampler over ``model.trainable_parameters``.

    Returns ``(hmc_helper, hmc, adaptive_hmc)``:
      * the gpflow ``SamplingHelper`` targeting ``model.log_posterior_density``;
      * a HamiltonianMonteCarlo kernel (10 leapfrog steps, step size 0.01);
      * that kernel wrapped in SimpleStepSizeAdaptation (10 adaptation steps,
        target acceptance 0.75, adaptation rate 0.1).
    """
    helper = gpflow.optimizers.SamplingHelper(
        model.log_posterior_density, model.trainable_parameters,
    )
    base_kernel = tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=helper.target_log_prob_fn,
        num_leapfrog_steps=10,
        step_size=0.01,
    )
    adaptive_kernel = tfp.mcmc.SimpleStepSizeAdaptation(
        base_kernel,
        num_adaptation_steps=10,
        target_accept_prob=0.75,
        adaptation_rate=0.1,
    )
    return helper, base_kernel, adaptive_kernel
100 |
101 |
def evaluate_kernel_mcmc(kernel, train_X, train_y, test_X, test_y, graph,
                         mean_function=None, transformer=None,
                         dump_everything=False, dump_directory=None, optimizer_name="Adam", n_iter=2000,
                         num_burnin_steps=300, num_samples=500, full_mcmc=False):
    """Fit an exact GPR by MAP estimation, optionally followed by HMC sampling.

    Priors: Normal(0, 1e-2) on the likelihood variance and Gamma(1, 1) on every
    trainable kernel parameter. The MAP fit uses gpflow's Scipy optimizer
    with a progress ``Callback``. With ``full_mcmc`` the model is then
    sampled with adaptive HMC; test predictions become the median of the
    function samples. ``dump_everything``, ``dump_directory`` and
    ``optimizer_name`` are accepted but unused.

    Returns:
        ``({"ELBO", "MAPE", "MAE"}, model)``. "ELBO" holds the final value of
        the training-loss closure at the model's current parameters.
    """
    if mean_function is None:
        mean_function = utils.ConstantArray(len(graph.nodes()))
    model = gpflow.models.GPR((train_X, train_y), kernel, mean_function, noise_variance=0.01)

    model.likelihood.variance.prior = tfd.Normal(f64(0.0), f64(1e-2))
    for var in kernel.trainable_parameters:
        var.prior = tfd.Gamma(f64(1.0), f64(1.0))
    optimizer = gpflow.optimizers.Scipy()
    loss_fn = model.training_loss_closure(compile=False)
    callback = utils.Callback(
        model, train_X, train_y, test_X, test_y,
        loss_fn=loss_fn, transformer=transformer)
    optimizer.minimize(
        loss_fn, model.trainable_variables, compile=False,
        options=dict(disp=True, maxiter=n_iter), callback=callback)
    if full_mcmc:
        hmc_helper, _, adaptive_hmc = initialize_hmc_helpers(model)
        samples, is_accepted = tfp.mcmc.sample_chain(
            num_results=num_samples,
            num_burnin_steps=num_burnin_steps,
            current_state=hmc_helper.current_state,
            kernel=adaptive_hmc,
            trace_fn=lambda _, pkr: pkr.inner_results.is_accepted,
        )
        # trace_fn returns the is_accepted tensor itself, so the trace *is* that
        # boolean tensor; it has no `.is_accepted` attribute.
        print("Acceptance rate:", np.mean(is_accepted.numpy()))

        # NOTE(review): with a single chain, potential_scale_reduction treats
        # axis 1 of each state as the chain axis; confirm this is intended.
        r_hat = tfp.mcmc.potential_scale_reduction(samples)
        print("R-hat diagnostic (per latent variable):", r_hat.numpy())

        # Constrained values are available via hmc_helper.convert_to_constrained_values(samples).
        f_samples = utils.get_hmc_sample(num_samples, samples, hmc_helper, model, test_X)
        y_pred = np.median(f_samples, 0)
        mape = utils.evaluate_mape_predictions(y_pred, test_y, transformer)
        mae = utils.evaluate_mae_predictions(y_pred, test_y, transformer)
    else:
        mape = utils.evaluate_mape(test_X, test_y, model, transformer)
        mae = utils.evaluate_mae(test_X, test_y, model, transformer)
    elbo = loss_fn()
    return {"ELBO": elbo.numpy(), "MAPE": mape, "MAE": mae}, model
146 |
--------------------------------------------------------------------------------
/graph_kernels/utils_postproc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import json
4 | import numpy as np
5 | import scipy
6 | import scipy.stats
7 |
8 |
def load_result(path, filename="result.json"):
    """Load ``path/filename`` as JSON and convert its top-level keys to int.

    The file is opened in binary mode, so ``json.load`` detects the encoding
    itself. It is closed deterministically; the original left the handle open.
    """
    full_path = os.path.join(path, filename)
    with open(full_path, "rb") as handle:
        result = json.load(handle)
    return {int(k): v for k, v in result.items()}
13 |
14 |
def separate_results(result):
    """Flatten ``{seed: {iteration: {"ELBO", "MSE", ...}}}`` into three parallel lists.

    Iteration keys may be ints or numeric strings. Within each seed, entries
    are taken in ascending iteration order. Returns
    ``(iterations, elbos, mses)``.
    """
    all_iterations, all_elbos, all_mses = [], [], []
    for seed_result in result.values():
        by_iteration = {int(key): metrics for key, metrics in seed_result.items()}
        ordered = sorted(by_iteration)
        all_iterations += ordered
        all_elbos += [by_iteration[it]["ELBO"] for it in ordered]
        all_mses += [by_iteration[it]["MSE"] for it in ordered]
    return all_iterations, all_elbos, all_mses
27 |
28 |
def separate_results_all_with_iter(result):
    """Flatten per-seed, per-iteration metric dicts into lists keyed by metric name.

    ``result`` maps seed -> {iteration (int or numeric str) -> {metric: value}}.
    A per-seed "time" entry is skipped. Returns a ``defaultdict(list)`` with an
    "iterations" list plus one list per metric, concatenated across seeds in
    ascending iteration order.

    ``result`` is not modified; the original deleted the "time" entries in
    place. Metric names come from each seed's first iteration, so seeds need
    not start at iteration 0. Seeds without iterations are skipped.
    """
    all_metrics = collections.defaultdict(list)
    for seed_result in result.values():
        cur_result = {int(k): v for k, v in seed_result.items() if k != "time"}
        if not cur_result:
            continue
        iterations = sorted(cur_result)

        all_metrics["iterations"].extend(iterations)
        for k in cur_result[iterations[0]].keys():
            all_metrics[k].extend([cur_result[i][k] for i in iterations])
    return all_metrics
41 |
42 |
def separate_results_all(result):
    """Group final metrics across seeds into ``{metric: [value for each seed]}``.

    Returns a ``defaultdict(list)``; seeds are visited in ``result``'s iteration order.
    """
    grouped = collections.defaultdict(list)
    for seed_metrics in result.values():
        for metric, value in seed_metrics.items():
            grouped[metric].append(value)
    return grouped
49 |
50 |
def stats_array(data):
    """Print the mean and a two-sided 95% Student-t confidence interval of ``data``.

    Returns ``(mean, ci)`` with ``ci = [lower, upper]``; the original returned
    None. The summary line previously printed a literal "95%%", because
    ``str.format`` does not treat "%%" as an escape.
    """
    mean = np.mean(data)
    # evaluate sample variance by setting delta degrees of freedom (ddof) to
    # 1. The degree used in calculations is N - ddof
    stddev = np.std(data, ddof=1)
    # Get the endpoints of the range that contains 95% of the distribution
    t_bounds = scipy.stats.t.interval(0.95, len(data) - 1)
    # sum mean to the confidence interval
    ci = [mean + critval * stddev / (len(data)**0.5) for critval in t_bounds]
    print("Mean: {:.4f} $\\pm$ {:.4f}".format(mean, ci[1] - mean))
    print("Confidence Interval 95%: {}, {}".format(ci[0], ci[1]))
    print(scipy.stats.t.interval(0.95, len(data) - 1, loc=np.mean(data), scale=scipy.stats.sem(data)))
    print("Data: ", data)
    return mean, ci
64 |
65 |
def print_statistics(result, n=2000):
    """Run ``stats_array`` on the metric values recorded at iteration ``n - 1``.

    ``result[0]`` holds iteration numbers and ``result[2]`` the matching metric
    values (the layout of the tuple returned by ``separate_results``).
    """
    final_iteration = n - 1
    data = [metric for iteration, metric in zip(result[0], result[2])
            if iteration == final_iteration]
    stats_array(data)
72 |
73 |
def from_folder_to_results(folder):
    """Load ``result.json`` from every entry of ``folder``, keyed by entry name."""
    return {
        entry: load_result(os.path.join(folder, entry))
        for entry in os.listdir(folder)
    }
79 |
80 |
def parse_results(results):
    """Apply ``separate_results_all`` to every experiment in ``results``."""
    return {
        experiment: separate_results_all(per_seed)
        for experiment, per_seed in results.items()
    }
86 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gpflow==2.1.4
2 | networkx==2.5
3 | tensorflow==2.5.1
4 | tensorflow-datasets==4.3.0
5 | tensorflow-hub==0.12.0
6 | tensorflow-probability==0.12.0
7 | gast==0.4.0
8 | scikit-learn==0.24.2
9 | numpy==1.19.5
10 | seaborn==0.11.0
11 | matplotlib==3.3.4
12 | scipy==1.5.4
13 | tqdm==4.62.3
14 | jupyter==1.0.0
15 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from setuptools import find_packages
3 |
# Package metadata for graph_kernels. install_requires is empty; the pinned
# dependencies are listed separately in requirements.txt.
package_metadata = dict(
    name='graph_kernels',
    version='0.0',
    description='Graph Kernels with GPFlow',
    author='',
    author_email='',
    url='',
    download_url='',
    license='Apache-2.0',
    install_requires=[],
    package_data={'graph_kernels': ['README.md']},
    packages=find_packages(),
)
setup(**package_metadata)
15 |
--------------------------------------------------------------------------------