├── .gitignore
├── README.md
├── dat
└── icews
│ └── undirected
│ └── 2003-D
│ └── data.npz
└── src
├── Makefile
├── icews_example.ipynb
├── impute.py
├── lambertw.pxd
├── lambertw.pyx
├── mcmc_model.pxd
├── mcmc_model.pyx
├── pgds.pyx
├── pp_plot.py
├── run_pgds.py
├── sample.pxd
├── sample.pyx
├── setup.py
└── test_pgds.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 |
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 |
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 |
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *.cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 |
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 |
60 | # Scrapy stuff:
61 | .scrapy
62 |
63 | # Sphinx documentation
64 | docs/_build/
65 |
66 | # PyBuilder
67 | target/
68 |
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 |
72 | # pyenv
73 | .python-version
74 |
75 | # celery beat schedule file
76 | celerybeat-schedule
77 |
78 | # dotenv
79 | .env
80 |
81 | # virtualenv
82 | venv/
83 | ENV/
84 |
85 | # Spyder project settings
86 | .spyderproject
87 |
88 | # Rope project settings
89 | .ropeproject
90 |
91 | ## Custom
92 | src/lambertw.c
93 | src/mcmc_model.c
94 | src/pgds.c
95 | src/sample.c
96 |
97 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Poisson-Gamma Dynamical Systems
2 | Source code for the paper: [Poisson-Gamma Dynamical Systems](http://people.cs.umass.edu/~aschein/ScheinZhouWallach2016_paper.pdf) by Aaron Schein, Mingyuan Zhou, and Hanna Wallach, presented at NIPS 2016.
3 |
4 | The MIT License (MIT)
5 |
6 | Copyright (c) 2016 Aaron Schein
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
26 | ## What's included:
27 | * [pgds.pyx](https://github.com/aschein/pgds/blob/master/src/pgds.pyx): The main code file. Implements Gibbs sampling inference for PGDS.
28 | * [mcmc_model.pyx](https://github.com/aschein/pgds/blob/master/src/mcmc_model.pyx): Implements Cython interface for MCMC models. Inherited by pgds.pyx.
29 | * [sample.pyx](https://github.com/aschein/pgds/blob/master/src/sample.pyx): Implements fast Cython method for sampling various distributions.
30 | * [lambertw.pyx](https://github.com/aschein/pgds/blob/master/src/lambertw.pyx): Code for computing the Lambert-W function.
31 | * [Makefile](https://github.com/aschein/pgds/blob/master/src/Makefile): Makefile (cd into this directory and type 'make' to compile).
32 | * [icews_example.ipynb](https://github.com/aschein/pgds/blob/master/src/icews_example.ipynb): Jupyter notebook with examples of how to use the code to run PGDS on ICEWS data for exploratory and predictive analyses.
33 |
34 | ## Dependencies:
35 | * numpy
36 | * scipy
37 | * matplotlib
38 | * seaborn
39 | * pandas
40 | * argparse
41 | * path
42 | * scikit-learn
43 | * cython
44 | * GSL
45 |
--------------------------------------------------------------------------------
/dat/icews/undirected/2003-D/data.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aschein/pgds/d2344d10eae1f807379d589ce4ca527b8a4660f5/dat/icews/undirected/2003-D/data.npz
--------------------------------------------------------------------------------
/src/Makefile:
--------------------------------------------------------------------------------
1 | all:
2 | python setup.py build_ext -i
3 |
4 | clean:
5 | rm -r build; rm *.c; rm *.cpp; rm *.so; rm *.html; rm *.pyc
6 |
--------------------------------------------------------------------------------
/src/icews_example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 33,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import numpy.random as rn\n",
13 | "%matplotlib notebook\n",
14 | "import matplotlib.pyplot as plt\n",
15 | "import seaborn as sns\n",
16 | "import pandas as pd\n",
17 | "from pgds import PGDS"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 34,
23 | "metadata": {
24 | "collapsed": true
25 | },
26 | "outputs": [],
27 | "source": [
28 | "data_dict = np.load('../dat/icews/undirected/2003-D/data.npz')"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 35,
34 | "metadata": {
35 | "collapsed": false
36 | },
37 | "outputs": [
38 | {
39 | "name": "stdout",
40 | "output_type": "stream",
41 | "text": [
42 | "T = 365 time steps\n",
43 | "V = 6197 features\n"
44 | ]
45 | }
46 | ],
47 | "source": [
48 | "Y_TV = data_dict['Y_TV'] # observed TxV count matrix\n",
49 | "(T, V) = Y_TV.shape\n",
50 | "print 'T = %d time steps' % T\n",
51 | "print 'V = %d features' % V"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": 36,
57 | "metadata": {
58 | "collapsed": false
59 | },
60 | "outputs": [
61 | {
62 | "name": "stdout",
63 | "output_type": "stream",
64 | "text": [
65 | "First time step: 2003-01-01T00:00:00.000000000\n",
66 | "Last time step: 2003-12-31T00:00:00.000000000\n"
67 | ]
68 | }
69 | ],
70 | "source": [
71 | "dates_T = data_dict['dates_T'] # time steps are days in 2003\n",
72 | "print 'First time step: %s' % dates_T[0]\n",
73 | "print 'Last time step: %s' % dates_T[-1]"
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": 37,
79 | "metadata": {
80 | "collapsed": false
81 | },
82 | "outputs": [
83 | {
84 | "name": "stdout",
85 | "output_type": "stream",
86 | "text": [
87 | "Most active feature: Iraq--United States\n",
88 | "Least active feature: Brazil--Uganda\n"
89 | ]
90 | }
91 | ],
92 | "source": [
93 | "labels_V = data_dict['labels_V'] # features are undirected edges of countries \n",
94 | "print 'Most active feature: %s' % labels_V[0]\n",
95 | "print 'Least active feature: %s' % labels_V[-1]"
96 | ]
97 | },
98 | {
99 | "cell_type": "markdown",
100 | "metadata": {},
101 | "source": [
102 | "# Exploratory analysis"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 38,
108 | "metadata": {
109 | "collapsed": false
110 | },
111 | "outputs": [],
112 | "source": [
113 | "K = 100 # number of latent components\n",
114 | "gam = 75 # shrinkage parameter\n",
115 | "tau = 1 # concentration parameter\n",
116 | "eps = 0.1 # uninformative gamma parameter\n",
117 | "stationary = True # stationary variant of the model\n",
118 | "steady = True # use steady state approx. (only for stationary)\n",
119 | "shrink = True # use the shrinkage version\n",
120 | "binary = False # whether the data is binary (vs. counts)\n",
121 | "seed = 111111 # random seed (optional)\n",
122 | "\n",
123 | "model = PGDS(T=T, V=V, K=K, eps=eps, gam=gam, tau=tau,\n",
124 | " stationary=int(stationary), steady=int(steady),\n",
125 | " shrink=int(shrink), binary=int(binary), seed=seed)"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": 39,
131 | "metadata": {
132 | "collapsed": false
133 | },
134 | "outputs": [],
135 | "source": [
136 | "num_itns = 1000 # number of Gibbs sampling iterations (the more the merrier)\n",
137 | "verbose = False # whether to print out state\n",
138 | "initialize = True # whether to initialize model randomly\n",
139 | "\n",
140 | "model.fit(data=Y_TV,\n",
141 | " num_itns=num_itns,\n",
142 | " verbose=verbose,\n",
143 | " initialize=initialize)"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": 40,
149 | "metadata": {
150 | "collapsed": false
151 | },
152 | "outputs": [],
153 | "source": [
154 | "state = dict(model.get_state())\n",
155 | "Theta_TK = state['Theta_TK'] # TxK time step factors\n",
156 | "Phi_KV = state['Phi_KV'] # KxV feature factors\n",
157 | "Pi_KK = state['Pi_KK'] # KxK transition matrix\n",
158 | "nu_K = state['nu_K'] # K component weights"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 41,
164 | "metadata": {
165 | "collapsed": false
166 | },
167 | "outputs": [
168 | {
169 | "name": "stdout",
170 | "output_type": "stream",
171 | "text": [
172 | "['Iraq--United States' 'Iraq--United Kingdom'\n",
173 | " 'United Kingdom--United States' 'Turkey--United States'\n",
174 | " 'Russian Federation--United States' 'Iraq--Russian Federation'\n",
175 | " 'Iraq--Turkey' 'France--United States' 'Australia--United States'\n",
176 | " 'South Korea--United States']\n"
177 | ]
178 | }
179 | ],
180 | "source": [
181 | "top_k = nu_K.argmax() # most active component\n",
182 | "features = Phi_KV[top_k].argsort()[::-1][:10] # top 10 features in top k\n",
183 | "print labels_V[features]"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 42,
189 | "metadata": {
190 | "collapsed": false
191 | },
192 | "outputs": [
193 | {
194 | "data": {
195 | "application/javascript": [
196 | "/* Put everything inside the global mpl namespace */\n",
197 | "window.mpl = {};\n",
198 | "\n",
199 | "mpl.get_websocket_type = function() {\n",
200 | " if (typeof(WebSocket) !== 'undefined') {\n",
201 | " return WebSocket;\n",
202 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n",
203 | " return MozWebSocket;\n",
204 | " } else {\n",
205 | " alert('Your browser does not have WebSocket support.' +\n",
206 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n",
207 | " 'Firefox 4 and 5 are also supported but you ' +\n",
208 | " 'have to enable WebSockets in about:config.');\n",
209 | " };\n",
210 | "}\n",
211 | "\n",
212 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n",
213 | " this.id = figure_id;\n",
214 | "\n",
215 | " this.ws = websocket;\n",
216 | "\n",
217 | " this.supports_binary = (this.ws.binaryType != undefined);\n",
218 | "\n",
219 | " if (!this.supports_binary) {\n",
220 | " var warnings = document.getElementById(\"mpl-warnings\");\n",
221 | " if (warnings) {\n",
222 | " warnings.style.display = 'block';\n",
223 | " warnings.textContent = (\n",
224 | " \"This browser does not support binary websocket messages. \" +\n",
225 | " \"Performance may be slow.\");\n",
226 | " }\n",
227 | " }\n",
228 | "\n",
229 | " this.imageObj = new Image();\n",
230 | "\n",
231 | " this.context = undefined;\n",
232 | " this.message = undefined;\n",
233 | " this.canvas = undefined;\n",
234 | " this.rubberband_canvas = undefined;\n",
235 | " this.rubberband_context = undefined;\n",
236 | " this.format_dropdown = undefined;\n",
237 | "\n",
238 | " this.image_mode = 'full';\n",
239 | "\n",
240 | "        this.root = $('<div/>');\n",
241 | " this._root_extra_style(this.root)\n",
242 | " this.root.attr('style', 'display: inline-block');\n",
243 | "\n",
244 | " $(parent_element).append(this.root);\n",
245 | "\n",
246 | " this._init_header(this);\n",
247 | " this._init_canvas(this);\n",
248 | " this._init_toolbar(this);\n",
249 | "\n",
250 | " var fig = this;\n",
251 | "\n",
252 | " this.waiting = false;\n",
253 | "\n",
254 | " this.ws.onopen = function () {\n",
255 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n",
256 | " fig.send_message(\"send_image_mode\", {});\n",
257 | " fig.send_message(\"refresh\", {});\n",
258 | " }\n",
259 | "\n",
260 | " this.imageObj.onload = function() {\n",
261 | " if (fig.image_mode == 'full') {\n",
262 | " // Full images could contain transparency (where diff images\n",
263 | " // almost always do), so we need to clear the canvas so that\n",
264 | " // there is no ghosting.\n",
265 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n",
266 | " }\n",
267 | " fig.context.drawImage(fig.imageObj, 0, 0);\n",
268 | " };\n",
269 | "\n",
270 | " this.imageObj.onunload = function() {\n",
271 | " this.ws.close();\n",
272 | " }\n",
273 | "\n",
274 | " this.ws.onmessage = this._make_on_message_function(this);\n",
275 | "\n",
276 | " this.ondownload = ondownload;\n",
277 | "}\n",
278 | "\n",
279 | "mpl.figure.prototype._init_header = function() {\n",
280 | " var titlebar = $(\n",
281 | '<div class=\"ui-dialog-titlebar ui-widget-header ui-corner-all ui-helper-clearfix\"/>');\n",
283 | " var titletext = $(\n",
284 | '<div class=\"ui-dialog-title\" style=\"width: 100%; text-align: center; padding: 3px;\"/>');\n",
286 | " titlebar.append(titletext)\n",
287 | " this.root.append(titlebar);\n",
288 | " this.header = titletext[0];\n",
289 | "}\n",
290 | "\n",
291 | "\n",
292 | "\n",
293 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n",
294 | "\n",
295 | "}\n",
296 | "\n",
297 | "\n",
298 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n",
299 | "\n",
300 | "}\n",
301 | "\n",
302 | "mpl.figure.prototype._init_canvas = function() {\n",
303 | " var fig = this;\n",
304 | "\n",
305 | "    var canvas_div = $('<div/>');\n",
306 | "\n",
307 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n",
308 | "\n",
309 | " function canvas_keyboard_event(event) {\n",
310 | " return fig.key_event(event, event['data']);\n",
311 | " }\n",
312 | "\n",
313 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n",
314 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n",
315 | " this.canvas_div = canvas_div\n",
316 | " this._canvas_extra_style(canvas_div)\n",
317 | " this.root.append(canvas_div);\n",
318 | "\n",
319 | "    var canvas = $('<canvas/>');\n",
320 | " canvas.addClass('mpl-canvas');\n",
321 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n",
322 | "\n",
323 | " this.canvas = canvas[0];\n",
324 | " this.context = canvas[0].getContext(\"2d\");\n",
325 | "\n",
326 | "    var rubberband = $('<canvas/>');\n",
327 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n",
328 | "\n",
329 | " var pass_mouse_events = true;\n",
330 | "\n",
331 | " canvas_div.resizable({\n",
332 | " start: function(event, ui) {\n",
333 | " pass_mouse_events = false;\n",
334 | " },\n",
335 | " resize: function(event, ui) {\n",
336 | " fig.request_resize(ui.size.width, ui.size.height);\n",
337 | " },\n",
338 | " stop: function(event, ui) {\n",
339 | " pass_mouse_events = true;\n",
340 | " fig.request_resize(ui.size.width, ui.size.height);\n",
341 | " },\n",
342 | " });\n",
343 | "\n",
344 | " function mouse_event_fn(event) {\n",
345 | " if (pass_mouse_events)\n",
346 | " return fig.mouse_event(event, event['data']);\n",
347 | " }\n",
348 | "\n",
349 | " rubberband.mousedown('button_press', mouse_event_fn);\n",
350 | " rubberband.mouseup('button_release', mouse_event_fn);\n",
351 | " // Throttle sequential mouse events to 1 every 20ms.\n",
352 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n",
353 | "\n",
354 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n",
355 | " rubberband.mouseleave('figure_leave', mouse_event_fn);\n",
356 | "\n",
357 | " canvas_div.on(\"wheel\", function (event) {\n",
358 | " event = event.originalEvent;\n",
359 | " event['data'] = 'scroll'\n",
360 | " if (event.deltaY < 0) {\n",
361 | " event.step = 1;\n",
362 | " } else {\n",
363 | " event.step = -1;\n",
364 | " }\n",
365 | " mouse_event_fn(event);\n",
366 | " });\n",
367 | "\n",
368 | " canvas_div.append(canvas);\n",
369 | " canvas_div.append(rubberband);\n",
370 | "\n",
371 | " this.rubberband = rubberband;\n",
372 | " this.rubberband_canvas = rubberband[0];\n",
373 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n",
374 | " this.rubberband_context.strokeStyle = \"#000000\";\n",
375 | "\n",
376 | " this._resize_canvas = function(width, height) {\n",
377 | " // Keep the size of the canvas, canvas container, and rubber band\n",
378 | " // canvas in synch.\n",
379 | " canvas_div.css('width', width)\n",
380 | " canvas_div.css('height', height)\n",
381 | "\n",
382 | " canvas.attr('width', width);\n",
383 | " canvas.attr('height', height);\n",
384 | "\n",
385 | " rubberband.attr('width', width);\n",
386 | " rubberband.attr('height', height);\n",
387 | " }\n",
388 | "\n",
389 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n",
390 | " // upon first draw.\n",
391 | " this._resize_canvas(600, 600);\n",
392 | "\n",
393 | " // Disable right mouse context menu.\n",
394 | " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n",
395 | " return false;\n",
396 | " });\n",
397 | "\n",
398 | " function set_focus () {\n",
399 | " canvas.focus();\n",
400 | " canvas_div.focus();\n",
401 | " }\n",
402 | "\n",
403 | " window.setTimeout(set_focus, 100);\n",
404 | "}\n",
405 | "\n",
406 | "mpl.figure.prototype._init_toolbar = function() {\n",
407 | " var fig = this;\n",
408 | "\n",
409 | "    var nav_element = $('<div/>')\n",
410 | " nav_element.attr('style', 'width: 100%');\n",
411 | " this.root.append(nav_element);\n",
412 | "\n",
413 | " // Define a callback function for later on.\n",
414 | " function toolbar_event(event) {\n",
415 | " return fig.toolbar_button_onclick(event['data']);\n",
416 | " }\n",
417 | " function toolbar_mouse_event(event) {\n",
418 | " return fig.toolbar_button_onmouseover(event['data']);\n",
419 | " }\n",
420 | "\n",
421 | " for(var toolbar_ind in mpl.toolbar_items) {\n",
422 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
423 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
424 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
425 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
426 | "\n",
427 | " if (!name) {\n",
428 | " // put a spacer in here.\n",
429 | " continue;\n",
430 | " }\n",
431 | "        var button = $('<button/>');\n",
432 | " button.addClass('ui-button ui-widget ui-state-default ui-corner-all ' +\n",
433 | " 'ui-button-icon-only');\n",
434 | " button.attr('role', 'button');\n",
435 | " button.attr('aria-disabled', 'false');\n",
436 | " button.click(method_name, toolbar_event);\n",
437 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
438 | "\n",
439 | "        var icon_img = $('<img/>');\n",
440 | " icon_img.addClass('ui-button-icon-primary ui-icon');\n",
441 | " icon_img.addClass(image);\n",
442 | " icon_img.addClass('ui-corner-all');\n",
443 | "\n",
444 | "        var tooltip_span = $('<span/>');\n",
445 | " tooltip_span.addClass('ui-button-text');\n",
446 | " tooltip_span.html(tooltip);\n",
447 | "\n",
448 | " button.append(icon_img);\n",
449 | " button.append(tooltip_span);\n",
450 | "\n",
451 | " nav_element.append(button);\n",
452 | " }\n",
453 | "\n",
454 | "    var fmt_picker_span = $('<span/>');\n",
455 | "\n",
456 | "    var fmt_picker = $('<select/>');\n",
457 | " fmt_picker.addClass('mpl-toolbar-option ui-widget ui-widget-content');\n",
458 | " fmt_picker_span.append(fmt_picker);\n",
459 | " nav_element.append(fmt_picker_span);\n",
460 | " this.format_dropdown = fmt_picker[0];\n",
461 | "\n",
462 | " for (var ind in mpl.extensions) {\n",
463 | " var fmt = mpl.extensions[ind];\n",
464 | " var option = $(\n",
465 | '<option/>', {selected: fmt === mpl.default_extension}).html(fmt);\n",
466 | " fmt_picker.append(option)\n",
467 | " }\n",
468 | "\n",
469 | " // Add hover states to the ui-buttons\n",
470 | " $( \".ui-button\" ).hover(\n",
471 | " function() { $(this).addClass(\"ui-state-hover\");},\n",
472 | " function() { $(this).removeClass(\"ui-state-hover\");}\n",
473 | " );\n",
474 | "\n",
475 | "    var status_bar = $('<span class=\"mpl-message\"/>');\n",
476 | " nav_element.append(status_bar);\n",
477 | " this.message = status_bar[0];\n",
478 | "}\n",
479 | "\n",
480 | "mpl.figure.prototype.request_resize = function(x_pixels, y_pixels) {\n",
481 | " // Request matplotlib to resize the figure. Matplotlib will then trigger a resize in the client,\n",
482 | " // which will in turn request a refresh of the image.\n",
483 | " this.send_message('resize', {'width': x_pixels, 'height': y_pixels});\n",
484 | "}\n",
485 | "\n",
486 | "mpl.figure.prototype.send_message = function(type, properties) {\n",
487 | " properties['type'] = type;\n",
488 | " properties['figure_id'] = this.id;\n",
489 | " this.ws.send(JSON.stringify(properties));\n",
490 | "}\n",
491 | "\n",
492 | "mpl.figure.prototype.send_draw_message = function() {\n",
493 | " if (!this.waiting) {\n",
494 | " this.waiting = true;\n",
495 | " this.ws.send(JSON.stringify({type: \"draw\", figure_id: this.id}));\n",
496 | " }\n",
497 | "}\n",
498 | "\n",
499 | "\n",
500 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
501 | " var format_dropdown = fig.format_dropdown;\n",
502 | " var format = format_dropdown.options[format_dropdown.selectedIndex].value;\n",
503 | " fig.ondownload(fig, format);\n",
504 | "}\n",
505 | "\n",
506 | "\n",
507 | "mpl.figure.prototype.handle_resize = function(fig, msg) {\n",
508 | " var size = msg['size'];\n",
509 | " if (size[0] != fig.canvas.width || size[1] != fig.canvas.height) {\n",
510 | " fig._resize_canvas(size[0], size[1]);\n",
511 | " fig.send_message(\"refresh\", {});\n",
512 | " };\n",
513 | "}\n",
514 | "\n",
515 | "mpl.figure.prototype.handle_rubberband = function(fig, msg) {\n",
516 | " var x0 = msg['x0'];\n",
517 | " var y0 = fig.canvas.height - msg['y0'];\n",
518 | " var x1 = msg['x1'];\n",
519 | " var y1 = fig.canvas.height - msg['y1'];\n",
520 | " x0 = Math.floor(x0) + 0.5;\n",
521 | " y0 = Math.floor(y0) + 0.5;\n",
522 | " x1 = Math.floor(x1) + 0.5;\n",
523 | " y1 = Math.floor(y1) + 0.5;\n",
524 | " var min_x = Math.min(x0, x1);\n",
525 | " var min_y = Math.min(y0, y1);\n",
526 | " var width = Math.abs(x1 - x0);\n",
527 | " var height = Math.abs(y1 - y0);\n",
528 | "\n",
529 | " fig.rubberband_context.clearRect(\n",
530 | " 0, 0, fig.canvas.width, fig.canvas.height);\n",
531 | "\n",
532 | " fig.rubberband_context.strokeRect(min_x, min_y, width, height);\n",
533 | "}\n",
534 | "\n",
535 | "mpl.figure.prototype.handle_figure_label = function(fig, msg) {\n",
536 | " // Updates the figure title.\n",
537 | " fig.header.textContent = msg['label'];\n",
538 | "}\n",
539 | "\n",
540 | "mpl.figure.prototype.handle_cursor = function(fig, msg) {\n",
541 | " var cursor = msg['cursor'];\n",
542 | " switch(cursor)\n",
543 | " {\n",
544 | " case 0:\n",
545 | " cursor = 'pointer';\n",
546 | " break;\n",
547 | " case 1:\n",
548 | " cursor = 'default';\n",
549 | " break;\n",
550 | " case 2:\n",
551 | " cursor = 'crosshair';\n",
552 | " break;\n",
553 | " case 3:\n",
554 | " cursor = 'move';\n",
555 | " break;\n",
556 | " }\n",
557 | " fig.rubberband_canvas.style.cursor = cursor;\n",
558 | "}\n",
559 | "\n",
560 | "mpl.figure.prototype.handle_message = function(fig, msg) {\n",
561 | " fig.message.textContent = msg['message'];\n",
562 | "}\n",
563 | "\n",
564 | "mpl.figure.prototype.handle_draw = function(fig, msg) {\n",
565 | " // Request the server to send over a new figure.\n",
566 | " fig.send_draw_message();\n",
567 | "}\n",
568 | "\n",
569 | "mpl.figure.prototype.handle_image_mode = function(fig, msg) {\n",
570 | " fig.image_mode = msg['mode'];\n",
571 | "}\n",
572 | "\n",
573 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
574 | " // Called whenever the canvas gets updated.\n",
575 | " this.send_message(\"ack\", {});\n",
576 | "}\n",
577 | "\n",
578 | "// A function to construct a web socket function for onmessage handling.\n",
579 | "// Called in the figure constructor.\n",
580 | "mpl.figure.prototype._make_on_message_function = function(fig) {\n",
581 | " return function socket_on_message(evt) {\n",
582 | " if (evt.data instanceof Blob) {\n",
583 | " /* FIXME: We get \"Resource interpreted as Image but\n",
584 | " * transferred with MIME type text/plain:\" errors on\n",
585 | " * Chrome. But how to set the MIME type? It doesn't seem\n",
586 | " * to be part of the websocket stream */\n",
587 | " evt.data.type = \"image/png\";\n",
588 | "\n",
589 | " /* Free the memory for the previous frames */\n",
590 | " if (fig.imageObj.src) {\n",
591 | " (window.URL || window.webkitURL).revokeObjectURL(\n",
592 | " fig.imageObj.src);\n",
593 | " }\n",
594 | "\n",
595 | " fig.imageObj.src = (window.URL || window.webkitURL).createObjectURL(\n",
596 | " evt.data);\n",
597 | " fig.updated_canvas_event();\n",
598 | " fig.waiting = false;\n",
599 | " return;\n",
600 | " }\n",
601 | " else if (typeof evt.data === 'string' && evt.data.slice(0, 21) == \"data:image/png;base64\") {\n",
602 | " fig.imageObj.src = evt.data;\n",
603 | " fig.updated_canvas_event();\n",
604 | " fig.waiting = false;\n",
605 | " return;\n",
606 | " }\n",
607 | "\n",
608 | " var msg = JSON.parse(evt.data);\n",
609 | " var msg_type = msg['type'];\n",
610 | "\n",
611 | " // Call the \"handle_{type}\" callback, which takes\n",
612 | " // the figure and JSON message as its only arguments.\n",
613 | " try {\n",
614 | " var callback = fig[\"handle_\" + msg_type];\n",
615 | " } catch (e) {\n",
616 | " console.log(\"No handler for the '\" + msg_type + \"' message type: \", msg);\n",
617 | " return;\n",
618 | " }\n",
619 | "\n",
620 | " if (callback) {\n",
621 | " try {\n",
622 | " // console.log(\"Handling '\" + msg_type + \"' message: \", msg);\n",
623 | " callback(fig, msg);\n",
624 | " } catch (e) {\n",
625 | " console.log(\"Exception inside the 'handler_\" + msg_type + \"' callback:\", e, e.stack, msg);\n",
626 | " }\n",
627 | " }\n",
628 | " };\n",
629 | "}\n",
630 | "\n",
631 | "// from http://stackoverflow.com/questions/1114465/getting-mouse-location-in-canvas\n",
632 | "mpl.findpos = function(e) {\n",
633 | " //this section is from http://www.quirksmode.org/js/events_properties.html\n",
634 | " var targ;\n",
635 | " if (!e)\n",
636 | " e = window.event;\n",
637 | " if (e.target)\n",
638 | " targ = e.target;\n",
639 | " else if (e.srcElement)\n",
640 | " targ = e.srcElement;\n",
641 | " if (targ.nodeType == 3) // defeat Safari bug\n",
642 | " targ = targ.parentNode;\n",
643 | "\n",
644 | " // jQuery normalizes the pageX and pageY\n",
645 | " // pageX,Y are the mouse positions relative to the document\n",
646 | " // offset() returns the position of the element relative to the document\n",
647 | " var x = e.pageX - $(targ).offset().left;\n",
648 | " var y = e.pageY - $(targ).offset().top;\n",
649 | "\n",
650 | " return {\"x\": x, \"y\": y};\n",
651 | "};\n",
652 | "\n",
653 | "/*\n",
654 | " * return a copy of an object with only non-object keys\n",
655 | " * we need this to avoid circular references\n",
656 | " * http://stackoverflow.com/a/24161582/3208463\n",
657 | " */\n",
658 | "function simpleKeys (original) {\n",
659 | " return Object.keys(original).reduce(function (obj, key) {\n",
660 | " if (typeof original[key] !== 'object')\n",
661 | " obj[key] = original[key]\n",
662 | " return obj;\n",
663 | " }, {});\n",
664 | "}\n",
665 | "\n",
666 | "mpl.figure.prototype.mouse_event = function(event, name) {\n",
667 | " var canvas_pos = mpl.findpos(event)\n",
668 | "\n",
669 | " if (name === 'button_press')\n",
670 | " {\n",
671 | " this.canvas.focus();\n",
672 | " this.canvas_div.focus();\n",
673 | " }\n",
674 | "\n",
675 | " var x = canvas_pos.x;\n",
676 | " var y = canvas_pos.y;\n",
677 | "\n",
678 | " this.send_message(name, {x: x, y: y, button: event.button,\n",
679 | " step: event.step,\n",
680 | " guiEvent: simpleKeys(event)});\n",
681 | "\n",
682 | " /* This prevents the web browser from automatically changing to\n",
683 | " * the text insertion cursor when the button is pressed. We want\n",
684 | " * to control all of the cursor setting manually through the\n",
685 | " * 'cursor' event from matplotlib */\n",
686 | " event.preventDefault();\n",
687 | " return false;\n",
688 | "}\n",
689 | "\n",
690 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
691 | " // Handle any extra behaviour associated with a key event\n",
692 | "}\n",
693 | "\n",
694 | "mpl.figure.prototype.key_event = function(event, name) {\n",
695 | "\n",
696 | " // Prevent repeat events\n",
697 | " if (name == 'key_press')\n",
698 | " {\n",
699 | " if (event.which === this._key)\n",
700 | " return;\n",
701 | " else\n",
702 | " this._key = event.which;\n",
703 | " }\n",
704 | " if (name == 'key_release')\n",
705 | " this._key = null;\n",
706 | "\n",
707 | " var value = '';\n",
708 | " if (event.ctrlKey && event.which != 17)\n",
709 | " value += \"ctrl+\";\n",
710 | " if (event.altKey && event.which != 18)\n",
711 | " value += \"alt+\";\n",
712 | " if (event.shiftKey && event.which != 16)\n",
713 | " value += \"shift+\";\n",
714 | "\n",
715 | " value += 'k';\n",
716 | " value += event.which.toString();\n",
717 | "\n",
718 | " this._key_event_extra(event, name);\n",
719 | "\n",
720 | " this.send_message(name, {key: value,\n",
721 | " guiEvent: simpleKeys(event)});\n",
722 | " return false;\n",
723 | "}\n",
724 | "\n",
725 | "mpl.figure.prototype.toolbar_button_onclick = function(name) {\n",
726 | " if (name == 'download') {\n",
727 | " this.handle_save(this, null);\n",
728 | " } else {\n",
729 | " this.send_message(\"toolbar_button\", {name: name});\n",
730 | " }\n",
731 | "};\n",
732 | "\n",
733 | "mpl.figure.prototype.toolbar_button_onmouseover = function(tooltip) {\n",
734 | " this.message.textContent = tooltip;\n",
735 | "};\n",
736 | "mpl.toolbar_items = [[\"Home\", \"Reset original view\", \"fa fa-home icon-home\", \"home\"], [\"Back\", \"Back to previous view\", \"fa fa-arrow-left icon-arrow-left\", \"back\"], [\"Forward\", \"Forward to next view\", \"fa fa-arrow-right icon-arrow-right\", \"forward\"], [\"\", \"\", \"\", \"\"], [\"Pan\", \"Pan axes with left mouse, zoom with right\", \"fa fa-arrows icon-move\", \"pan\"], [\"Zoom\", \"Zoom to rectangle\", \"fa fa-square-o icon-check-empty\", \"zoom\"], [\"\", \"\", \"\", \"\"], [\"Download\", \"Download plot\", \"fa fa-floppy-o icon-save\", \"download\"]];\n",
737 | "\n",
738 | "mpl.extensions = [\"eps\", \"jpeg\", \"pdf\", \"png\", \"ps\", \"raw\", \"svg\", \"tif\"];\n",
739 | "\n",
740 | "mpl.default_extension = \"png\";var comm_websocket_adapter = function(comm) {\n",
741 | " // Create a \"websocket\"-like object which calls the given IPython comm\n",
742 | " // object with the appropriate methods. Currently this is a non binary\n",
743 | " // socket, so there is still some room for performance tuning.\n",
744 | " var ws = {};\n",
745 | "\n",
746 | " ws.close = function() {\n",
747 | " comm.close()\n",
748 | " };\n",
749 | " ws.send = function(m) {\n",
750 | " //console.log('sending', m);\n",
751 | " comm.send(m);\n",
752 | " };\n",
753 | " // Register the callback with on_msg.\n",
754 | " comm.on_msg(function(msg) {\n",
755 | " //console.log('receiving', msg['content']['data'], msg);\n",
756 | " // Pass the mpl event to the overriden (by mpl) onmessage function.\n",
757 | " ws.onmessage(msg['content']['data'])\n",
758 | " });\n",
759 | " return ws;\n",
760 | "}\n",
761 | "\n",
762 | "mpl.mpl_figure_comm = function(comm, msg) {\n",
763 | " // This is the function which gets called when the mpl process\n",
764 | " // starts-up an IPython Comm through the \"matplotlib\" channel.\n",
765 | "\n",
766 | " var id = msg.content.data.id;\n",
767 | " // Get hold of the div created by the display call when the Comm\n",
768 | " // socket was opened in Python.\n",
769 | " var element = $(\"#\" + id);\n",
770 | " var ws_proxy = comm_websocket_adapter(comm)\n",
771 | "\n",
772 | " function ondownload(figure, format) {\n",
773 | " window.open(figure.imageObj.src);\n",
774 | " }\n",
775 | "\n",
776 | " var fig = new mpl.figure(id, ws_proxy,\n",
777 | " ondownload,\n",
778 | " element.get(0));\n",
779 | "\n",
780 | " // Call onopen now - mpl needs it, as it is assuming we've passed it a real\n",
781 | " // web socket which is closed, not our websocket->open comm proxy.\n",
782 | " ws_proxy.onopen();\n",
783 | "\n",
784 | " fig.parent_element = element.get(0);\n",
785 | " fig.cell_info = mpl.find_output_cell(\"\");\n",
786 | " if (!fig.cell_info) {\n",
787 | " console.error(\"Failed to find cell for figure\", id, fig);\n",
788 | " return;\n",
789 | " }\n",
790 | "\n",
791 | " var output_index = fig.cell_info[2]\n",
792 | " var cell = fig.cell_info[0];\n",
793 | "\n",
794 | "};\n",
795 | "\n",
796 | "mpl.figure.prototype.handle_close = function(fig, msg) {\n",
797 | " fig.root.unbind('remove')\n",
798 | "\n",
799 | " // Update the output cell to use the data from the current canvas.\n",
800 | " fig.push_to_output();\n",
801 | " var dataURL = fig.canvas.toDataURL();\n",
802 | " // Re-enable the keyboard manager in IPython - without this line, in FF,\n",
803 | " // the notebook keyboard shortcuts fail.\n",
804 | " IPython.keyboard_manager.enable()\n",
 805 | "    $(fig.parent_element).html('<img src=\"' + dataURL + '\">');\n",
806 | " fig.close_ws(fig, msg);\n",
807 | "}\n",
808 | "\n",
809 | "mpl.figure.prototype.close_ws = function(fig, msg){\n",
810 | " fig.send_message('closing', msg);\n",
811 | " // fig.ws.close()\n",
812 | "}\n",
813 | "\n",
814 | "mpl.figure.prototype.push_to_output = function(remove_interactive) {\n",
815 | " // Turn the data on the canvas into data in the output cell.\n",
816 | " var dataURL = this.canvas.toDataURL();\n",
 817 | "    this.cell_info[1]['text/html'] = '<img src=\"' + dataURL + '\">';\n",
818 | "}\n",
819 | "\n",
820 | "mpl.figure.prototype.updated_canvas_event = function() {\n",
821 | " // Tell IPython that the notebook contents must change.\n",
822 | " IPython.notebook.set_dirty(true);\n",
823 | " this.send_message(\"ack\", {});\n",
824 | " var fig = this;\n",
825 | " // Wait a second, then push the new image to the DOM so\n",
826 | " // that it is saved nicely (might be nice to debounce this).\n",
827 | " setTimeout(function () { fig.push_to_output() }, 1000);\n",
828 | "}\n",
829 | "\n",
830 | "mpl.figure.prototype._init_toolbar = function() {\n",
831 | " var fig = this;\n",
832 | "\n",
833 | " var nav_element = $('')\n",
834 | " nav_element.attr('style', 'width: 100%');\n",
835 | " this.root.append(nav_element);\n",
836 | "\n",
837 | " // Define a callback function for later on.\n",
838 | " function toolbar_event(event) {\n",
839 | " return fig.toolbar_button_onclick(event['data']);\n",
840 | " }\n",
841 | " function toolbar_mouse_event(event) {\n",
842 | " return fig.toolbar_button_onmouseover(event['data']);\n",
843 | " }\n",
844 | "\n",
845 | " for(var toolbar_ind in mpl.toolbar_items){\n",
846 | " var name = mpl.toolbar_items[toolbar_ind][0];\n",
847 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n",
848 | " var image = mpl.toolbar_items[toolbar_ind][2];\n",
849 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n",
850 | "\n",
851 | " if (!name) { continue; };\n",
852 | "\n",
853 | " var button = $('');\n",
854 | " button.click(method_name, toolbar_event);\n",
855 | " button.mouseover(tooltip, toolbar_mouse_event);\n",
856 | " nav_element.append(button);\n",
857 | " }\n",
858 | "\n",
859 | " // Add the status bar.\n",
860 | " var status_bar = $('');\n",
861 | " nav_element.append(status_bar);\n",
862 | " this.message = status_bar[0];\n",
863 | "\n",
864 | " // Add the close button to the window.\n",
865 | " var buttongrp = $('');\n",
866 | " var button = $('');\n",
867 | " button.click(function (evt) { fig.handle_close(fig, {}); } );\n",
868 | " button.mouseover('Stop Interaction', toolbar_mouse_event);\n",
869 | " buttongrp.append(button);\n",
870 | " var titlebar = this.root.find($('.ui-dialog-titlebar'));\n",
871 | " titlebar.prepend(buttongrp);\n",
872 | "}\n",
873 | "\n",
874 | "mpl.figure.prototype._root_extra_style = function(el){\n",
875 | " var fig = this\n",
876 | " el.on(\"remove\", function(){\n",
877 | "\tfig.close_ws(fig, {});\n",
878 | " });\n",
879 | "}\n",
880 | "\n",
881 | "mpl.figure.prototype._canvas_extra_style = function(el){\n",
882 | " // this is important to make the div 'focusable\n",
883 | " el.attr('tabindex', 0)\n",
884 | " // reach out to IPython and tell the keyboard manager to turn it's self\n",
885 | " // off when our div gets focus\n",
886 | "\n",
887 | " // location in version 3\n",
888 | " if (IPython.notebook.keyboard_manager) {\n",
889 | " IPython.notebook.keyboard_manager.register_events(el);\n",
890 | " }\n",
891 | " else {\n",
892 | " // location in version 2\n",
893 | " IPython.keyboard_manager.register_events(el);\n",
894 | " }\n",
895 | "\n",
896 | "}\n",
897 | "\n",
898 | "mpl.figure.prototype._key_event_extra = function(event, name) {\n",
899 | " var manager = IPython.notebook.keyboard_manager;\n",
900 | " if (!manager)\n",
901 | " manager = IPython.keyboard_manager;\n",
902 | "\n",
903 | " // Check for shift+enter\n",
904 | " if (event.shiftKey && event.which == 13) {\n",
905 | " this.canvas_div.blur();\n",
906 | " event.shiftKey = false;\n",
907 | " // Send a \"J\" for go to next cell\n",
908 | " event.which = 74;\n",
909 | " event.keyCode = 74;\n",
910 | " manager.command_mode();\n",
911 | " manager.handle_keydown(event);\n",
912 | " }\n",
913 | "}\n",
914 | "\n",
915 | "mpl.figure.prototype.handle_save = function(fig, msg) {\n",
916 | " fig.ondownload(fig, null);\n",
917 | "}\n",
918 | "\n",
919 | "\n",
920 | "mpl.find_output_cell = function(html_output) {\n",
921 | " // Return the cell and output element which can be found *uniquely* in the notebook.\n",
922 | " // Note - this is a bit hacky, but it is done because the \"notebook_saving.Notebook\"\n",
923 | " // IPython event is triggered only after the cells have been serialised, which for\n",
924 | " // our purposes (turning an active figure into a static one), is too late.\n",
925 | " var cells = IPython.notebook.get_cells();\n",
926 | " var ncells = cells.length;\n",
927 | " for (var i=0; i= 3 moved mimebundle to data attribute of output\n",
934 | " data = data.data;\n",
935 | " }\n",
936 | " if (data['text/html'] == html_output) {\n",
937 | " return [cell, data, j];\n",
938 | " }\n",
939 | " }\n",
940 | " }\n",
941 | " }\n",
942 | "}\n",
943 | "\n",
944 | "// Register the function which deals with the matplotlib target/channel.\n",
945 | "// The kernel may be null if the page has been refreshed.\n",
946 | "if (IPython.notebook.kernel != null) {\n",
947 | " IPython.notebook.kernel.comm_manager.register_target('matplotlib', mpl.mpl_figure_comm);\n",
948 | "}\n"
949 | ],
950 | "text/plain": [
951 | ""
952 | ]
953 | },
954 | "metadata": {},
955 | "output_type": "display_data"
956 | },
957 | {
958 | "data": {
959 | "text/html": [
960 | "
"
961 | ],
962 | "text/plain": [
963 | ""
964 | ]
965 | },
966 | "metadata": {},
967 | "output_type": "display_data"
968 | }
969 | ],
970 | "source": [
971 | "k = top_k # which component to visualize (replace with any k = 1...K to further explore)\n",
972 | "\n",
973 | "fig = plt.figure()\n",
974 | "ax1 = fig.add_subplot(211)\n",
975 | "ax2 = fig.add_subplot(212)\n",
976 | "\n",
977 | "# Plot time-step factors for component k\n",
978 | "ax1.plot(Theta_TK[:, k])\n",
979 | "xticks = np.linspace(10, Theta_TK.shape[0] - 10, num=8, dtype=int)\n",
980 | "ax1.set_xticks(xticks)\n",
981 | "formatted_dates = np.asarray([pd.Timestamp(x).strftime('%b %Y') for x in dates_T])\n",
982 | "ax1.set_xticklabels(formatted_dates[xticks], weight='bold')\n",
983 | "\n",
984 | "# Plot top N feature factors for component k\n",
985 | "N = 15\n",
986 | "top_N_features = Phi_KV[k].argsort()[::-1][:N]\n",
987 | "ax2.stem(Phi_KV[k, top_N_features])\n",
988 | "ax2.set_xticks(range(N))\n",
989 | "ax2.set_xticklabels(labels_V[top_N_features], rotation=330, weight='bold', ha='left')\n",
990 | "\n",
991 | "plt.show()"
992 | ]
993 | },
994 | {
995 | "cell_type": "code",
996 | "execution_count": 43,
997 | "metadata": {
998 | "collapsed": false
999 | },
1000 | "outputs": [
1001 | {
1002 | "name": "stdout",
1003 | "output_type": "stream",
1004 | "text": [
1005 | "Top date: March 20 2003\n",
1006 | "Top N features: ['Iraq--United States' 'Iraq--United Kingdom'\n",
1007 | " 'United Kingdom--United States']\n",
1008 | "Example Google query: 'March 20 2003 Iraq United States'\n"
1009 | ]
1010 | }
1011 | ],
1012 | "source": [
1013 | "# Pro tip: To find interpret the components, Google the top date and top feature(s).\n",
1014 | "\n",
1015 | "k = top_k # which component to explore\n",
1016 | "\n",
1017 | "top_t = Theta_TK[:, k].argmax() # top time-step factor\n",
1018 | "top_date = dates_T[top_t] # date of top time-step factor (top date)\n",
1019 | "formatted_date = pd.Timestamp(top_date).strftime('%B %d %Y')\n",
1020 | "print 'Top date: %s' % formatted_date\n",
1021 | "\n",
1022 | "N = 3\n",
1023 | "top_vs = Phi_KV[k, :].argsort()[::-1][:N] # top N features\n",
1024 | "top_features = labels_V[top_vs]\n",
1025 | "print 'Top N features: %s' % top_features\n",
1026 | "\n",
1027 | "query = formatted_date + ' ' + top_features[0].replace('--', ' ')\n",
1028 | "print \"Example Google query: '%s'\" % query\n"
1029 | ]
1030 | },
1031 | {
1032 | "cell_type": "markdown",
1033 | "metadata": {},
1034 | "source": [
1035 | "# Predictive analysis"
1036 | ]
1037 | },
1038 | {
1039 | "cell_type": "code",
1040 | "execution_count": 44,
1041 | "metadata": {
1042 | "collapsed": false
1043 | },
1044 | "outputs": [],
1045 | "source": [
1046 | "from run_pgds import mre, mae"
1047 | ]
1048 | },
1049 | {
1050 | "cell_type": "code",
1051 | "execution_count": 45,
1052 | "metadata": {
1053 | "collapsed": true
1054 | },
1055 | "outputs": [],
1056 | "source": [
1057 | "S = 2 # how many final time-steps to hold out completely (for forecasting)\n",
1058 | "\n",
1059 | "data_SV = Y_TV[-S:] # hold out final S time steps (to forecast)\n",
1060 | "train_TV = Y_TV[:-S] # training data matrix is everything up to t=S\n",
1061 | "(T, V) = train_TV.shape \n",
1062 | "\n",
1063 | "percent = 0.1 # percent of training data to mask (for smoothing)\n",
1064 | "mask_TV = (rn.random(size=(T, V)) < percent).astype(bool) # create a TxV random mask \n",
1065 | "\n",
1066 | "masked_train_TV = np.ma.array(train_TV, mask=mask_TV)"
1067 | ]
1068 | },
1069 | {
1070 | "cell_type": "code",
1071 | "execution_count": 46,
1072 | "metadata": {
1073 | "collapsed": true
1074 | },
1075 | "outputs": [],
1076 | "source": [
1077 | "K = 100 # number of latent components\n",
1078 | "gam = 75 # shrinkage parameter\n",
1079 | "tau = 1 # concentration parameter\n",
1080 | "eps = 0.1 # uninformative gamma parameter\n",
1081 | "stationary = True # stationary variant of the model\n",
1082 | "steady = True # use steady state approx. (only for stationary)\n",
1083 | "shrink = True # use the shrinkage version\n",
1084 | "binary = False # whether the data is binary (vs. counts)\n",
1085 | "seed = 111111 # random seed (optional)\n",
1086 | "\n",
1087 | "model = PGDS(T=T, V=V, K=K, eps=eps, gam=gam, tau=tau,\n",
1088 | " stationary=int(stationary), steady=int(steady),\n",
1089 | " shrink=int(shrink), binary=int(binary), seed=seed)"
1090 | ]
1091 | },
1092 | {
1093 | "cell_type": "code",
1094 | "execution_count": 47,
1095 | "metadata": {
1096 | "collapsed": false
1097 | },
1098 | "outputs": [],
1099 | "source": [
1100 | "num_itns = 100 # number of Gibbs sampling iterations (the more the merrier)\n",
1101 | "verbose = False # whether to print out state\n",
1102 | "initialize = True # whether to initialize model randomly\n",
1103 | "\n",
1104 | "model.fit(data=masked_train_TV,\n",
1105 | " num_itns=num_itns,\n",
1106 | " verbose=verbose,\n",
1107 | " initialize=initialize)"
1108 | ]
1109 | },
1110 | {
1111 | "cell_type": "code",
1112 | "execution_count": 48,
1113 | "metadata": {
1114 | "collapsed": false
1115 | },
1116 | "outputs": [
1117 | {
1118 | "name": "stdout",
1119 | "output_type": "stream",
1120 | "text": [
1121 | "Mean absolute error on reconstructing observed data: 0.274571\n",
1122 | "Mean relative error on reconstructing observed data: 0.184305\n",
1123 | "Mean absolute error on smoothing missing data: 0.436285\n",
1124 | "Mean relative error on smoothing missing data: 0.308913\n"
1125 | ]
1126 | }
1127 | ],
1128 | "source": [
1129 | "pred_TV = model.reconstruct() # reconstruction of the training matrix\n",
1130 | "\n",
1131 | "train_mae = mae(train_TV[~mask_TV], pred_TV[~mask_TV])\n",
1132 | "train_mre = mre(train_TV[~mask_TV], pred_TV[~mask_TV])\n",
1133 | "\n",
1134 | "smooth_mae = mae(train_TV[mask_TV], pred_TV[mask_TV])\n",
1135 | "smooth_mre = mre(train_TV[mask_TV], pred_TV[mask_TV])\n",
1136 | "\n",
1137 | "print 'Mean absolute error on reconstructing observed data: %f' % train_mae\n",
1138 | "print 'Mean relative error on reconstructing observed data: %f' % train_mre\n",
1139 | "print 'Mean absolute error on smoothing missing data: %f' % smooth_mae\n",
1140 | "print 'Mean relative error on smoothing missing data: %f' % smooth_mre"
1141 | ]
1142 | },
1143 | {
1144 | "cell_type": "code",
1145 | "execution_count": 49,
1146 | "metadata": {
1147 | "collapsed": false
1148 | },
1149 | "outputs": [
1150 | {
1151 | "name": "stdout",
1152 | "output_type": "stream",
1153 | "text": [
1154 | "Mean absolute error on forecasting: 0.324120\n",
1155 | "Mean relative error on forecasting: 0.241739\n"
1156 | ]
1157 | }
1158 | ],
1159 | "source": [
1160 | "pred_SV = model.forecast(n_timesteps=S) # forecast of next S time steps\n",
1161 | "\n",
1162 | "forecast_mae = mae(data_SV, pred_SV)\n",
1163 | "forecast_mre = mre(data_SV, pred_SV)\n",
1164 | "\n",
1165 | "print 'Mean absolute error on forecasting: %f' % forecast_mae\n",
1166 | "print 'Mean relative error on forecasting: %f' % forecast_mre"
1167 | ]
1168 | },
1169 | {
1170 | "cell_type": "code",
1171 | "execution_count": 50,
1172 | "metadata": {
1173 | "collapsed": false
1174 | },
1175 | "outputs": [],
1176 | "source": [
1177 | "initialize = False # run some more iterations---i.e., don't reinitialize\n",
1178 | "\n",
1179 | "model.fit(data=masked_train_TV,\n",
1180 | " num_itns=num_itns,\n",
1181 | " verbose=verbose,\n",
1182 | " initialize=initialize)"
1183 | ]
1184 | },
1185 | {
1186 | "cell_type": "code",
1187 | "execution_count": 51,
1188 | "metadata": {
1189 | "collapsed": false
1190 | },
1191 | "outputs": [
1192 | {
1193 | "name": "stdout",
1194 | "output_type": "stream",
1195 | "text": [
1196 | "Mean absolute error on reconstructing observed data: 0.268175\n",
1197 | "Mean relative error on reconstructing observed data: 0.181987\n",
1198 | "Mean absolute error on smoothing missing data: 0.440969\n",
1199 | "Mean relative error on smoothing missing data: 0.315346\n"
1200 | ]
1201 | }
1202 | ],
1203 | "source": [
1204 | "new_pred_TV = model.reconstruct() # reconstruction of the training matrix from new sample\n",
1205 | "avg_pred_TV = (pred_TV + new_pred_TV) / 2. # compute avg. reconstruction\n",
1206 | "\n",
1207 | "train_mae = mae(train_TV[~mask_TV], avg_pred_TV[~mask_TV])\n",
1208 | "train_mre = mre(train_TV[~mask_TV], avg_pred_TV[~mask_TV])\n",
1209 | "\n",
1210 | "smooth_mae = mae(train_TV[mask_TV], avg_pred_TV[mask_TV])\n",
1211 | "smooth_mre = mre(train_TV[mask_TV], avg_pred_TV[mask_TV])\n",
1212 | "\n",
1213 | "print 'Mean absolute error on reconstructing observed data: %f' % train_mae\n",
1214 | "print 'Mean relative error on reconstructing observed data: %f' % train_mre\n",
1215 | "print 'Mean absolute error on smoothing missing data: %f' % smooth_mae\n",
1216 | "print 'Mean relative error on smoothing missing data: %f' % smooth_mre"
1217 | ]
1218 | },
1219 | {
1220 | "cell_type": "code",
1221 | "execution_count": 52,
1222 | "metadata": {
1223 | "collapsed": false
1224 | },
1225 | "outputs": [
1226 | {
1227 | "name": "stdout",
1228 | "output_type": "stream",
1229 | "text": [
1230 | "Mean absolute error on forecasting: 0.320113\n",
1231 | "Mean relative error on forecasting: 0.238774\n"
1232 | ]
1233 | }
1234 | ],
1235 | "source": [
1236 | "new_pred_SV = model.forecast(n_timesteps=S) # forecast of next S time steps from new sample\n",
1237 | "avg_pred_SV = (pred_SV + new_pred_SV) / 2. # compute avg. forecast\n",
1238 | "\n",
1239 | "forecast_mae = mae(data_SV, avg_pred_SV) \n",
1240 | "forecast_mre = mre(data_SV, avg_pred_SV)\n",
1241 | "\n",
1242 | "print 'Mean absolute error on forecasting: %f' % forecast_mae\n",
1243 | "print 'Mean relative error on forecasting: %f' % forecast_mre"
1244 | ]
1245 | }
1246 | ],
1247 | "metadata": {
1248 | "anaconda-cloud": {},
1249 | "kernelspec": {
1250 | "display_name": "Python [Root]",
1251 | "language": "python",
1252 | "name": "Python [Root]"
1253 | },
1254 | "language_info": {
1255 | "codemirror_mode": {
1256 | "name": "ipython",
1257 | "version": 2
1258 | },
1259 | "file_extension": ".py",
1260 | "mimetype": "text/x-python",
1261 | "name": "python",
1262 | "nbconvert_exporter": "python",
1263 | "pygments_lexer": "ipython2",
1264 | "version": "2.7.12"
1265 | }
1266 | },
1267 | "nbformat": 4,
1268 | "nbformat_minor": 0
1269 | }
1270 |
--------------------------------------------------------------------------------
/src/impute.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def init_missing_data(masked_data):
    """Fill in the masked entries of a (T, V) count matrix with initial guesses.

    Each missing entry is initialized via simple temporal interpolation: the
    average of the nearest observed values before and after it (or the single
    neighbor when only one side has an observation).  A feature column with no
    observations at all falls back to the global mean of the observed counts.

    Parameters
    ----------
    masked_data : np.ma.MaskedArray, shape (T, V)
        Count data over T time steps and V features; masked entries are
        treated as missing.

    Returns
    -------
    Y_TV : np.ndarray of int32, shape (T, V)
        Dense matrix in which every previously-masked entry holds a
        non-negative initial guess.
    """
    T, V = masked_data.shape
    # Start from the observed values; masked cells get -1 so the final
    # assertion can catch any entry we fail to impute.
    Y_TV = masked_data.astype(np.int32).filled(fill_value=-1)
    lam = masked_data.mean()  # global mean of the observed counts

    for v in range(V):  # range (not xrange) keeps this Python 2/3 compatible
        if not masked_data.mask[:, v].any():
            continue  # nothing missing for this feature

        if masked_data.mask[:, v].all():
            # No observations at all for feature v: use the global mean.
            Y_TV[:, v] = int(np.round(lam))
        else:
            # Time steps at which data for feature v is missing.
            time_indices = np.where(masked_data.mask[:, v])[0]

            for t in time_indices:
                next_obs_y = None
                for s in range(t + 1, T):  # look ahead for an observed val
                    if not masked_data.mask[s, v]:
                        next_obs_y = Y_TV[s, v]
                        break

                prev_obs_y = None
                for s in range(t - 1, -1, -1):  # look behind for an observed val
                    if not masked_data.mask[s, v]:
                        prev_obs_y = Y_TV[s, v]
                        break

                if prev_obs_y is None and next_obs_y is None:
                    # Defensive fallback: with at least one observation in
                    # this column (guaranteed by the all() check above) this
                    # branch should be unreachable.
                    lam_tv = lam
                elif prev_obs_y is None:
                    lam_tv = next_obs_y
                elif next_obs_y is None:
                    lam_tv = prev_obs_y
                else:
                    # Observations on both sides: interpolate between them.
                    lam_tv = (next_obs_y + prev_obs_y) / 2.

                Y_TV[t, v] = int(np.round(lam_tv))
    assert (Y_TV >= 0).all()
    return Y_TV
49 |
--------------------------------------------------------------------------------
/src/lambertw.pxd:
--------------------------------------------------------------------------------
#!python
#cython: boundscheck=False
#cython: cdivision=True
#cython: infertypes=True
#cython: initializedcheck=False
#cython: nonecheck=False
#cython: wraparound=False
#distutils: extra_link_args = ['-lgsl', '-lgslcblas']
#distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration

# Declarations for computing the quantity "zeta": _simulate_zeta iterates
# the map zeta <- a * log1p((zeta + c) / b); _calculate_zeta evaluates an
# expression using GSL's Lambert W_{-1} branch — presumably the closed form
# of that same fixed point (TODO confirm the algebra).

from libc.math cimport exp, log1p

# GSL's secondary (W_{-1}) branch of the Lambert W function.
cdef extern from "gsl/gsl_sf_lambert.h" nogil:
    double gsl_sf_lambert_Wm1(double x)

# Python-callable wrapper; implemented in lambertw.pyx.
cpdef double calculate_zeta(double a, double b, double c) nogil

# Closed-form expression for zeta in terms of Lambert W_{-1}.
cdef inline double _calculate_zeta(double a, double b, double c) nogil:
    return -a * gsl_sf_lambert_Wm1(-b / (a * exp((b + c) / a))) - b - c

# Python-callable wrapper; implemented in lambertw.pyx.
cpdef double simulate_zeta(double a, double b, double c) nogil

# Fixed-point iteration for zeta: a fixed budget of 35 iterations from
# zeta = 1 (no convergence test).
cdef inline double _simulate_zeta(double a, double b, double c) nogil:
    cdef:
        int _
        double zeta

    zeta = 1.
    for _ in range(35):
        zeta = a * log1p((zeta + c) / b)
    return zeta
32 |
--------------------------------------------------------------------------------
/src/lambertw.pyx:
--------------------------------------------------------------------------------
#!python
#cython: boundscheck=False
#cython: cdivision=True
#cython: infertypes=True
#cython: initializedcheck=False
#cython: nonecheck=False
#cython: wraparound=False
#distutils: extra_link_args = ['-lgsl', '-lgslcblas']
#distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration

import sys
import numpy as np
cimport numpy as np
from time import time

# Thin cpdef wrappers exposing the inline cdef implementations declared in
# lambertw.pxd (implicitly cimported for this module) to Python callers.

cpdef double calculate_zeta(double a, double b, double c) nogil:
    # Closed-form zeta via GSL's Lambert W_{-1}; see _calculate_zeta.
    return _calculate_zeta(a, b, c)

cpdef double simulate_zeta(double a, double b, double c) nogil:
    # Fixed-point-iteration approximation of zeta; see _simulate_zeta.
    return _simulate_zeta(a, b, c)
21 |
22 |
--------------------------------------------------------------------------------
/src/mcmc_model.pxd:
--------------------------------------------------------------------------------
#!python
#cython: boundscheck=False
#cython: cdivision=True
#cython: infertypes=True
#cython: initializedcheck=False
#cython: nonecheck=False
#cython: wraparound=False
#distutils: language = c
#distutils: extra_link_args = ['-lgsl', '-lgslcblas']
#distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration


# Minimal slice of the GSL random-number-generator API used by MCMCModel:
# Mersenne-Twister allocation, seeding, and cleanup.
cdef extern from "gsl/gsl_rng.h" nogil:
    ctypedef struct gsl_rng_type:
        pass
    ctypedef struct gsl_rng:
        pass
    gsl_rng_type *gsl_rng_mt19937
    gsl_rng *gsl_rng_alloc(gsl_rng_type * T)
    void gsl_rng_set(gsl_rng * r, unsigned long int)
    void gsl_rng_free(gsl_rng * r)


# Base class for Gibbs-sampling models; implementations in mcmc_model.pyx,
# concrete subclass in pgds.pyx.
cdef class MCMCModel:
    cdef:
        gsl_rng *rng                   # GSL RNG owned (and freed) by this object
        int total_itns, print_every    # iteration counter / print period
        dict param_list                # constructor params, incl. 'seed'

    cdef list _get_variables(self)
    cdef void _generate_state(self)
    cdef void _generate_data(self)
    cdef void _init_state(self)
    cdef void _print_state(self)
    cdef void _update(self, int num_itns, int verbose, dict burnin)
    cpdef void update(self, int num_itns, int verbose, dict burnin=?)
    cdef void _test(self,
                    int num_samples,
                    str method=?,
                    dict var_funcs=?,
                    dict burnin=?)
    cdef void _calc_funcs(self, int n, dict var_funcs, dict out)
    cpdef void geweke(self, int num_samples, dict var_funcs=?, dict burnin=?)
    cpdef void schein(self, int num_samples, dict var_funcs=?, dict burnin=?)
45 |
--------------------------------------------------------------------------------
/src/mcmc_model.pyx:
--------------------------------------------------------------------------------
1 | #!python
2 | #cython: boundscheck=False
3 | #cython: cdivision=True
4 | #cython: infertypes=True
5 | #cython: initializedcheck=False
6 | #cython: nonecheck=False
7 | #cython: wraparound=False
8 | #distutils: language = c
9 | #distutils: extra_link_args = ['-lgsl', '-lgslcblas']
10 | #distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration
11 |
12 | import sys
13 | import numpy as np
14 | cimport numpy as np
15 | import scipy.stats as st
16 | from numpy.random import randint
17 | from pp_plot import pp_plot
18 | from copy import deepcopy
19 | from time import time
20 | from contextlib import contextmanager
21 |
22 |
23 |
@contextmanager
def timeit_context(name):
    """Context manager that prints the wall-clock time spent in its block.

    Parameters
    ----------
    name : str
        Label printed next to the elapsed time (in milliseconds).

    Note: if the block raises, the timing line is not printed (the
    exception propagates before the post-yield code runs).
    """
    start_time = time()
    yield
    elapsed_time = time() - start_time
    # Parenthesized single-argument form prints identically under
    # Python 2 and Python 3.
    print('%3.4fms: %s' % (elapsed_time * 1000, name))
30 |
31 |
cdef class MCMCModel:
    """
    Base class for Gibbs-sampling models.

    Owns a GSL Mersenne-Twister RNG, tracks the total number of sampling
    iterations, and provides the Geweke/Schein testing harness.  Subclasses
    implement _get_variables, _generate_state, _generate_data, _init_state,
    and _print_state.
    """

    def __init__(self, object seed=None):

        self.rng = gsl_rng_alloc(gsl_rng_mt19937)

        if seed is None:
            # Mask to 32 bits: gsl_rng_set takes an unsigned long seed.
            seed = randint(0, sys.maxint) & 0xFFFFFFFF
        gsl_rng_set(self.rng, seed)

        self.total_itns = 0
        self.print_every = 10
        self.param_list = {'seed': seed}

    def __dealloc__(self):
        """
        Free GSL random number generator.
        """

        gsl_rng_free(self.rng)

    def get_total_itns(self):
        """
        Return the number of itns the model has done inference for.
        """

        return self.total_itns

    def get_params(self):
        """
        Get a copy of the initialization params.

        Inheriting objects should add params to the param_list, e.g.:

        cdef class ExampleModel(MCMCModel):

            def __init__(self, double alpha=1., object seed=None):

                super(ExampleModel, self).__init__(seed)

                self.param_list['alpha'] = alpha

            ...
        """
        return deepcopy(self.param_list)

    cdef list _get_variables(self):
        """
        Return variable names, values, and sampling methods for testing.

        Example:

            return [('foo', self.foo, self._sample_foo),
                    ('bar', self.bar, self._sample_bar)]
        """
        pass

    def get_state(self):
        """
        Wrapper around _get_variables(...).

        Returns only the names and values of variables (not update funcs).
        """
        # Generator: yields (name, value) pairs; memoryviews are copied
        # into fresh ndarrays so callers get a snapshot, not a view.
        for k, v, update_func in self._get_variables():
            if np.isscalar(v):
                yield k, v
            else:
                yield k, np.array(v)

    cdef void _generate_state(self):
        """
        Generate internal state.
        """

        pass

    cdef void _generate_data(self):
        """
        Generate data given internal state.
        """

        pass

    cdef void _init_state(self):
        """
        Initialize internal state.
        """

        pass

    cdef void _print_state(self):
        """
        Print internal state.
        """
        cdef:
            double t  # NOTE(review): declared but unused in this base impl

        print 'ITERATION %d\n' % self.total_itns

    cdef void _update(self, int num_itns, int verbose, dict burnin):
        """
        Perform inference.

        burnin maps a variable name to the iteration index before which
        that variable's update is skipped.
        """

        cdef:
            int n

        for n in range(num_itns):
            for k, _, update_func in self._get_variables():
                if k not in burnin.keys() or n >= burnin[k]:
                    # When verbose, time each variable's update on the
                    # iterations where state is also printed.
                    if (verbose == 1) and ((n + 1) % self.print_every == 0):
                        with timeit_context('sampling %s' % k):
                            update_func(self)
                    else:
                        update_func(self)
            self.total_itns += 1
            if (verbose == 1) and ((n + 1) % self.print_every == 0):
                self._print_state()

    cpdef void update(self, int num_itns, int verbose, dict burnin={}):
        """
        Thin wrapper around _update(...).
        """
        self._update(num_itns, verbose, burnin)

    cdef void _test(self,
                    int num_samples,
                    str method='geweke',
                    dict var_funcs={},
                    dict burnin={}):
        """
        Perform Geweke testing or Schein testing.

        Draws num_samples from the forward generative process and from the
        Gibbs sampler, summarizes each variable with up to four scalar
        functions, and P-P plots forward vs. reverse samples per variable.

        NOTE(review): mutable dict defaults — burnin is mutated below
        (burnin[k] = 0), so the shared default / caller's dict is modified
        across calls; var_funcs is deepcopied first and is safe.
        """

        cdef:
            int n
            dict default_funcs, fwd, rev

        default_funcs = {'Arith. Mean': np.mean,
                         'Geom. Mean': lambda x: np.exp(np.mean(np.log1p(x))),
                         'Var.': np.var,
                         'Max.': np.max}

        fwd, rev = {}, {}
        var_funcs = deepcopy(var_funcs)  # this method changes var_funcs state
        for k, v, _ in self._get_variables():

            if k not in burnin.keys():
                burnin[k] = 0

            # Variables burned in past the sample budget are never sampled;
            # exclude them from the test entirely.
            if burnin[k] > num_samples:
                if k in var_funcs.keys():
                    del var_funcs[k]
                continue

            if k not in var_funcs.keys():
                var_funcs[k] = default_funcs
            assert len(var_funcs[k].keys()) <= 4

            # Scalars are stored directly; arrays store one trace per
            # summary function.
            if np.isscalar(v):
                fwd[k] = np.empty(num_samples)
                rev[k] = np.empty(num_samples)
            else:
                fwd[k] = {}
                rev[k] = {}
                for f in var_funcs[k]:
                    fwd[k][f] = np.empty(num_samples)
                    rev[k][f] = np.empty(num_samples)

        if method == 'schein':
            # Schein: each reverse sample re-generates data after 10 Gibbs
            # steps from a fresh forward draw.
            for n in range(num_samples):
                self._generate_state()
                self._generate_data()
                self._calc_funcs(n, var_funcs, fwd)

                self._update(10, 0, burnin)
                self._generate_data()
                self._calc_funcs(n, var_funcs, rev)
                if n % 500 == 0:
                    print n
        else:
            # Geweke: forward samples are independent draws; reverse samples
            # come from one successive-conditional chain (10 steps apart).
            for n in range(num_samples):
                self._generate_state()
                self._generate_data()
                self._calc_funcs(n, var_funcs, fwd)
                if n % 500 == 0:
                    print n

            self._generate_state()
            for n in range(num_samples):
                self._generate_data()
                self._update(10, 0, burnin)
                self._calc_funcs(n, var_funcs, rev)
                if n % 500 == 0:
                    print n

        for k, _, _ in self._get_variables():
            if not burnin[k] > num_samples:
                pp_plot(fwd[k], rev[k], k)

    cdef void _calc_funcs(self,
                          int n,
                          dict var_funcs,
                          dict out):
        """
        Helper function for _test. Calculates and stores functions of variables.
        """

        for k, v, _ in self._get_variables():
            if k not in var_funcs.keys():
                continue
            if np.isscalar(v):
                out[k][n] = v
            else:
                for f, func in var_funcs[k].iteritems():
                    out[k][f][n] = func(v)

    cpdef void geweke(self,
                      int num_samples,
                      dict var_funcs={},
                      dict burnin={}):
        """
        Wrapper around _test(...).
        """
        self._test(num_samples=num_samples,
                   method='geweke',
                   var_funcs=var_funcs,
                   burnin=burnin)

    cpdef void schein(self,
                      int num_samples,
                      dict var_funcs={},
                      dict burnin={}):
        """
        Wrapper around _test(...).
        """
        self._test(num_samples=num_samples,
                   method='schein',
                   var_funcs=var_funcs,
                   burnin=burnin)
272 |
--------------------------------------------------------------------------------
/src/pgds.pyx:
--------------------------------------------------------------------------------
1 | #!python
2 | #cython: boundscheck=False
3 | #cython: cdivision=True
4 | #cython: infertypes=True
5 | #cython: initializedcheck=False
6 | #cython: nonecheck=False
7 | #cython: wraparound=False
8 | #distutils: extra_link_args = ['-lgsl', '-lgslcblas']
9 | #distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration
10 |
11 | import sys
12 | import numpy as np
13 | cimport numpy as np
14 | from libc.math cimport exp, log, log1p, expm1
15 |
16 | from mcmc_model cimport MCMCModel
17 | from sample cimport _sample_gamma, _sample_dirichlet, _sample_lnbeta,\
18 | _sample_crt, _sample_trunc_poisson
19 |
20 | from lambertw cimport _simulate_zeta
21 | from impute import init_missing_data
22 |
23 | cdef extern from "gsl/gsl_rng.h" nogil:
24 | ctypedef struct gsl_rng:
25 | pass
26 |
27 | cdef extern from "gsl/gsl_randist.h" nogil:
28 | double gsl_rng_uniform(gsl_rng * r)
29 | unsigned int gsl_ran_poisson(gsl_rng * r, double mu)
30 | void gsl_ran_multinomial(gsl_rng * r,
31 | size_t K,
32 | unsigned int N,
33 | const double p[],
34 | unsigned int n[])
35 |
36 | cdef extern from "gsl/gsl_sf_psi.h" nogil:
37 | double gsl_sf_psi(double)
38 |
39 |
40 | cdef class PGDS(MCMCModel):
41 |
42 | cdef:
43 | int T, V, K, P, shrink, stationary, steady, binary, y_
44 | double tau, gam, beta, eps, theta_, nu_
45 | double[::1] nu_K, xi_K, delta_T, zeta_T, P_K, lnq_K, Theta_T, shp_V
46 | double[:,::1] Pi_KK, Theta_TK, shp_KK, Phi_KV, R_TK
47 | int[::1] Y_T, vals_P, L_K
48 | int[:,::1] Y_TV, Y_TK, Y_KV, L_TK, L_KK, H_KK, mask_TV, data_TV, subs_P2
49 | int[:,:,::1] L_TKK
50 | unsigned int[::1] N_K, N_V
51 |
    def __init__(self, int T, int V, int K, double eps=0.1, double gam=10.,
                 double tau=1., int shrink=1, int stationary=0, int steady=0,
                 int binary=0, object seed=None):
        """
        Poisson--gamma dynamical system (PGDS).

        Parameters
        ----------
        T : number of time steps
        V : number of features
        K : number of latent components
        eps : uninformative gamma hyperparameter
        gam : shrinkage parameter
        tau : concentration parameter
        shrink : 1 to use the shrinkage version of the model
        stationary : 1 for the stationary variant of the model
        steady : 1 to use the steady-state approximation (stationary only)
        binary : 1 if the data is binary rather than counts
        seed : optional random seed (passed through to MCMCModel)
        """

        self.print_every = 25
        super(PGDS, self).__init__(seed)

        # Record every hyperparameter in param_list (see MCMCModel.get_params).
        self.T = self.param_list['T'] = T
        self.V = self.param_list['V'] = V
        self.K = self.param_list['K'] = K
        self.eps = self.param_list['eps'] = eps
        self.gam = self.param_list['gam'] = gam
        self.tau = self.param_list['tau'] = tau
        self.binary = self.param_list['binary'] = binary
        self.shrink = self.param_list['shrink'] = shrink
        self.stationary = self.param_list['stationary'] = stationary
        self.steady = self.param_list['steady'] = steady
        if steady == 1 and stationary == 0:
            raise ValueError('Steady-state only valid for stationary model.')

        # Model state arrays, allocated up front and overwritten during MCMC.
        self.beta = 1.
        self.nu_K = np.zeros(K)
        self.xi_K = np.zeros(K)
        self.Pi_KK = np.zeros((K, K))
        self.Theta_TK = np.zeros((T, K))
        self.Theta_T = np.zeros(T)
        self.theta_ = 0
        self.delta_T = np.zeros(T)
        self.Phi_KV = np.zeros((K, V))

        # Auxiliary / latent-count arrays used by the Gibbs sampler.
        self.zeta_T = np.zeros(T)
        self.shp_KK = np.zeros((K, K))
        self.R_TK = np.zeros((T, K))
        self.H_KK = np.zeros((K, K), dtype=np.int32)
        self.L_K = np.zeros(K, dtype=np.int32)
        self.L_KK = np.zeros((K, K), dtype=np.int32)
        self.L_TK = np.zeros((T, K), dtype=np.int32)
        self.L_TKK = np.zeros((T, K, K), dtype=np.int32)
        self.Y_TV = np.zeros((T, V), dtype=np.int32)
        self.Y_TK = np.zeros((T, K), dtype=np.int32)
        self.Y_KV = np.zeros((K, V), dtype=np.int32)
        self.Y_T = np.zeros(T, dtype=np.int32)
        self.y_ = 0
        self.P_K = np.zeros(K)
        self.lnq_K = np.zeros(K)
        self.N_K = np.zeros(K, dtype=np.uint32)
        self.N_V = np.zeros(V, dtype=np.uint32)
        self.shp_V = np.zeros(V)

        # Observed-data and missing-mask buffers — presumably populated by
        # fit(), which accepts a (possibly masked) data array; confirm there.
        self.data_TV = np.zeros((T, V), dtype=np.int32)
        self.mask_TV = np.zeros((T, V), dtype=np.int32)

        self.P = 0 # placeholder
        self.subs_P2 = np.zeros((self.P, 2), dtype=np.int32)
        self.vals_P = np.zeros(self.P, dtype=np.int32)
107 |
108 | def fit(self, data, num_itns=1000, verbose=True, initialize=True, burnin={}):
109 | if not isinstance(data, np.ma.core.MaskedArray):
110 | data = np.ma.array(data, mask=None)
111 |
112 | assert data.shape == (self.T, self.V)
113 | assert (data >= 0).all()
114 | if self.binary == 1:
115 | assert (data <= 1).all()
116 |
117 | filled_data = data.astype(np.int32).filled(fill_value=-1)
118 | subs = filled_data.nonzero()
119 | self.vals_P = filled_data[subs] # missing values will be -1
120 | self.subs_P2 = np.array(zip(*subs), dtype=np.int32)
121 | self.P = self.vals_P.shape[0]
122 |
123 | if self.binary == 0:
124 | # filled_data[filled_data == -1] = 0
125 | # self.Y_TV = np.ascontiguousarray(filled_data, dtype=np.int32)
126 | self.Y_TV = np.ascontiguousarray(init_missing_data(data))
127 | self.Y_T = np.sum(self.Y_TV, axis=1, dtype=np.int32)
128 | self.y_ = np.sum(self.Y_T)
129 |
130 | if initialize:
131 | self._init_state()
132 |
133 | self._update(num_itns=num_itns, verbose=int(verbose), burnin=burnin)
134 |
135 | def reconstruct(self, subs=(), partial_state={}):
136 | Theta_TK = np.array(self.Theta_TK)
137 | if 'Theta_TK' in partial_state.keys():
138 | Theta_TK = partial_state['Theta_TK']
139 |
140 | Phi_KV = np.array(self.Phi_KV)
141 | if 'Phi_KV' in partial_state.keys():
142 | Phi_KV = partial_state['Phi_KV']
143 |
144 | delta_T = np.array(self.delta_T)
145 | if 'delta_T' in partial_state.keys():
146 | delta_T = partial_state['delta_T']
147 |
148 | if not subs:
149 | rates_TV = np.einsum('tk,t,kv->tv', Theta_TK, delta_T, Phi_KV)
150 | if self.binary == 1:
151 | rates_TV = -np.expm1(-rates_TV)
152 | return rates_TV
153 |
154 | else:
155 | Theta_PK = Theta_TK[subs[0]]
156 | delta_P = delta_T[subs[0]]
157 | Phi_KP = Phi_KV[:, subs[1]]
158 | rates_P = np.einsum('pk,kp,p->p', Theta_PK, Phi_KP, delta_P)
159 | if self.binary == 1:
160 | rates_P = -np.expm1(-rates_P)
161 | return rates_P
162 |
163 | cdef void _forecast(self, double[:,:,::1] rates_CSV, str mode='arithmetic'):
164 | cdef:
165 | int C, S, K, c, s, k, v
166 | double[:,::1] Theta_SK
167 | double[::1] delta_S, shp_K
168 | double rte, delta, tau, mu_csv
169 |
170 | C = rates_CSV.shape[0]
171 | S = rates_CSV.shape[1]
172 | K = self.K
173 |
174 | delta_S = np.zeros(S)
175 | if self.stationary == 1:
176 | delta_S[:] = self.delta_T[0]
177 | else:
178 | delta_S[:] = np.mean(self.delta_T[self.T-2:])
179 |
180 | Theta_SK = np.zeros((S, K))
181 |
182 | tau = self.tau
183 | rte = 1. / tau
184 | for c in range(C):
185 | for s in xrange(S):
186 | if s == 0:
187 | shp_K = tau * np.dot(self.Theta_TK[self.T-1], self.Pi_KK)
188 | else:
189 | shp_K = tau * np.dot(Theta_SK[s-1], self.Pi_KK)
190 |
191 | delta = delta_S[s]
192 | for k in range(K):
193 |
194 | if mode == 'arithmetic':
195 | Theta_SK[s, k] = shp_K[k] / rte
196 | elif mode == 'geometric':
197 | Theta_SK[s, k] = exp(gsl_sf_psi(shp_K[k]) - log(rte))
198 | else:
199 | Theta_SK[s, k] = _sample_gamma(self.rng, shp_K[k], rte)
200 |
201 | for v in range(self.V):
202 | mu_csv = delta * np.dot(Theta_SK[s], self.Phi_KV[:, v])
203 | if self.binary == 1:
204 | rates_CSV[c, s, v] = -expm1(-mu_csv)
205 | else:
206 | rates_CSV[c, s, v] = mu_csv
207 |
208 | def forecast(self, n_timesteps=1, n_chains=1, mode='arithmetic'):
209 | assert mode in ['arithmetic', 'geometric', 'sample']
210 | if mode != 'sample':
211 | assert n_chains == 1
212 |
213 | rates_CSV = np.zeros((n_chains, n_timesteps, self.V))
214 | self._forecast(rates_CSV=rates_CSV, mode=mode)
215 | return rates_CSV[0] if n_chains == 1 else rates_CSV
216 |
217 | cdef list _get_variables(self):
218 | """
219 | Return variable names, values, and sampling methods for testing.
220 | """
221 | variables = [('Y_KV', self.Y_KV, self._update_Y_TVK),
222 | ('Y_TK', self.Y_TK, lambda x: None),
223 | ('L_TKK', self.L_TKK, self._update_L_TKK)]
224 |
225 | if self.stationary == 0:
226 | variables += [('delta_T', self.delta_T, self._update_delta_T)]
227 | else:
228 | variables += [('delta_T', self.delta_T[0], self._update_delta_T)]
229 |
230 | variables += [('Phi_KV', self.Phi_KV, self._update_Phi_KV),
231 | ('Theta_TK', self.Theta_TK, self._update_Theta_TK),
232 | ('Pi_KK', self.Pi_KK, self._update_Pi_KK)]
233 |
234 | if self.shrink == 0:
235 | variables += [('beta', self.beta, self._update_beta),
236 | ('nu_K', self.nu_K, self._update_nu_K)]
237 | else:
238 | variables += [('beta', self.beta, self._update_beta),
239 | ('H_KK', self.H_KK, self._update_H_KK_and_lnq_K),
240 | ('lnq_K', self.lnq_K, lambda x: None),
241 | ('nu_K', self.nu_K, self._update_nu_K),
242 | ('xi_K', self.xi_K[0], self._update_xi_K)]
243 | return variables
244 |
245 | cdef void _init_state(self):
246 | cdef:
247 | int K, k, k2, t
248 | double eps, tau, shape, theta_tk, nu_k
249 | gsl_rng * rng
250 |
251 | K = self.K
252 | eps = self.eps
253 | tau = self.tau
254 | rng = self.rng
255 |
256 | self.delta_T[:] = 1.
257 | self._update_zeta_T()
258 |
259 | self.beta = 1.
260 | self.tau = 1.
261 |
262 | self.xi_K[:] = 1.
263 |
264 | self.nu_ = 0
265 | for k in range(K):
266 | nu_k = _sample_gamma(rng, self.gam / float(K), 1. / self.beta)
267 | self.nu_K[k] = nu_k
268 | self.nu_ += nu_k
269 |
270 | for k in range(K):
271 | for k2 in range(K):
272 | self.shp_KK[k, k2] = _sample_gamma(rng, 1., 1.)
273 | _sample_dirichlet(rng, self.shp_KK[k], self.Pi_KK[k])
274 | assert np.isfinite(self.Pi_KK[k]).all()
275 | assert self.Pi_KK[k, 0] >= 0
276 |
277 | self.theta_ = 0
278 | self.Theta_T[:] = 0
279 | for t in range(self.T):
280 | for k in range(self.K):
281 | theta_tk = _sample_gamma(rng, 1., 1.)
282 | self.Theta_TK[t, k] = theta_tk
283 | self.Theta_T[t] += theta_tk
284 | self.theta_ += theta_tk
285 |
286 | for k in range(self.K):
287 | _sample_dirichlet(rng, np.ones(self.V), self.Phi_KV[k])
288 | assert self.Phi_KV[k, 0] >= 0
289 |
290 | cdef void _print_state(self):
291 | cdef:
292 | int l
293 | double theta, delta
294 |
295 | l = np.sum(self.L_KK)
296 | theta = np.mean(self.Theta_TK)
297 | delta = np.mean(self.delta_T)
298 | print 'ITERATION %d: total aux counts: %d\t \
299 | mean theta: %.4f\t mean delta: %f\n' % \
300 | (self.total_itns, l, theta, delta)
301 |
302 | cdef void _generate_state(self):
303 | """
304 | Generate internal state.
305 | """
306 | cdef:
307 | int K, k, k1, k2, t
308 | double eps, tau, shape, theta_tk, nu_k
309 | gsl_rng * rng
310 |
311 | K = self.K
312 | eps = self.eps
313 | tau = self.tau
314 | rng = self.rng
315 |
316 | self.beta = _sample_gamma(rng, eps, 1. / eps)
317 |
318 | self.nu_ = 0
319 | for k in range(K):
320 | nu_k = _sample_gamma(rng, self.gam / K, 1. / self.beta)
321 | assert np.isfinite(nu_k)
322 | self.nu_K[k] = nu_k
323 | self.nu_ += nu_k
324 |
325 | self.xi_K[:] = 1
326 | if self.shrink == 1:
327 | self.xi_K[:] = _sample_gamma(rng, eps, 1. / eps)
328 | assert np.isfinite(self.xi_K).all()
329 |
330 | for k in range(K):
331 | self.shp_KK[k, :] = self.eps
332 | if self.shrink == 1:
333 | self.shp_KK[k, k] = self.nu_K[k] * self.xi_K[k]
334 | for k2 in range(K):
335 | if k == k2:
336 | continue
337 | self.shp_KK[k, k2] = self.nu_K[k] * self.nu_K[k2]
338 | _sample_dirichlet(rng, self.shp_KK[k], self.Pi_KK[k])
339 | assert np.isfinite(self.Pi_KK[k]).all()
340 | assert self.Pi_KK[k, 0] >= 0
341 |
342 | self.theta_ = 0
343 | self.Theta_T[:] = 0
344 | for k in range(K):
345 | shape = tau * self.nu_K[k]
346 | theta_tk = _sample_gamma(rng, shape, 1. / tau)
347 | self.Theta_TK[0, k] = theta_tk
348 | self.Theta_T[0] += theta_tk
349 | self.theta_ += theta_tk
350 |
351 | for t in range(1, self.T):
352 | for k in range(K):
353 | shape = tau * np.dot(self.Theta_TK[t-1], self.Pi_KK[:, k])
354 | theta_tk = _sample_gamma(rng, shape, 1. / tau)
355 | self.Theta_TK[t, k] = theta_tk
356 | self.Theta_T[t] += theta_tk
357 | self.theta_ += theta_tk
358 |
359 | for k in range(self.K):
360 | _sample_dirichlet(rng, np.ones(self.V) * eps, self.Phi_KV[k])
361 | assert self.Phi_KV[k, 0] >= 0
362 |
363 | if self.stationary == 0:
364 | for t in range(self.T):
365 | self.delta_T[t] = _sample_gamma(rng, eps, 1. / eps)
366 | else:
367 | self.delta_T[:] = _sample_gamma(rng, eps, 1. / eps)
368 | self._update_zeta_T()
369 |
370 | cdef void _generate_data(self):
371 | """
372 | Generate data given internal state.
373 | """
374 | cdef:
375 | int t, v, k
376 | unsigned int y_tk, y_tvk
377 | double mu_tk
378 | tuple subs
379 |
380 | self.y_ = 0
381 | self.Y_T[:] = 0
382 | self.Y_TV[:] = 0
383 | self.Y_TK[:] = 0
384 | self.Y_KV[:] = 0
385 |
386 | for t in range(self.T):
387 | for k in range(self.K):
388 | mu_tk = self.Theta_TK[t, k] * self.delta_T[t]
389 | y_tk = gsl_ran_poisson(self.rng, mu_tk)
390 |
391 | if y_tk > 0:
392 |
393 | self.Y_TK[t, k] = y_tk
394 | self.Y_T[t] += y_tk
395 | self.y_ += y_tk
396 |
397 | gsl_ran_multinomial(self.rng,
398 | self.V,
399 | y_tk,
400 | &self.Phi_KV[k, 0],
401 | &self.N_V[0])
402 | for v in range(self.V):
403 | y_tvk = self.N_V[v]
404 | if y_tvk > 0:
405 | self.Y_KV[k, v] += y_tvk
406 | self.Y_TV[t, v] += y_tvk
407 |
408 | if self.y_ > 0:
409 | subs = np.nonzero(self.Y_TV)
410 | self.subs_P2 = np.array(zip(*subs), dtype=np.int32)
411 | self.P = self.subs_P2.shape[0]
412 | if self.binary == 0:
413 | self.vals_P = np.array(self.Y_TV)[subs]
414 | else:
415 | self.P = 0
416 |
417 | self._update_L_TKK()
418 | self._update_H_KK_and_lnq_K()
419 |
420 | cdef void _update_zeta_T(self) nogil:
421 | cdef:
422 | int t
423 | double tmp
424 |
425 | if self.steady == 1:
426 | self.zeta_T[:] = _simulate_zeta(self.tau, self.tau, self.delta_T[1])
427 | else:
428 | self.zeta_T[self.T-1] = 0
429 | for t in range(self.T-2,-1,-1):
430 | tmp = (self.zeta_T[t+1] + self.delta_T[t+1]) / self.tau
431 | self.zeta_T[t] = self.tau * log1p(tmp)
432 |
433 | cdef void _update_Y_TVK(self) nogil:
434 | cdef:
435 | int p, t, v, k, y_tv
436 | double norm, u, mu_tv
437 | unsigned int y_tvk
438 |
439 | self.Y_TK[:] = 0
440 | self.Y_KV[:] = 0
441 | for p in range(self.P):
442 | t = self.subs_P2[p, 0]
443 | v = self.subs_P2[p, 1]
444 | y_tv = self.Y_TV[t, v]
445 |
446 | if (self.vals_P[p] == -1) or (self.binary == 1):
447 | self.y_ -= y_tv
448 | self.Y_T[t] -= y_tv
449 |
450 | mu_tv = 0
451 | for k in range(self.K):
452 | mu_tv += self.Theta_TK[t, k] * self.Phi_KV[k, v]
453 | mu_tv *= self.delta_T[t]
454 |
455 | if (self.vals_P[p] == -1) and (self.total_itns > 500):
456 | y_tv = gsl_ran_poisson(self.rng, mu_tv)
457 | else:
458 | y_tv = _sample_trunc_poisson(self.rng, mu_tv)
459 |
460 | self.Y_TV[t, v] = y_tv
461 | self.Y_T[t] += y_tv
462 | self.y_ += y_tv
463 |
464 | if y_tv == 0:
465 | continue
466 |
467 | for k in range(self.K):
468 | self.P_K[k] = self.Theta_TK[t, k] * self.Phi_KV[k, v]
469 |
470 | gsl_ran_multinomial(self.rng,
471 | self.K,
472 | y_tv,
473 | &self.P_K[0],
474 | &self.N_K[0])
475 |
476 | for k in range(self.K):
477 | y_tvk = self.N_K[k]
478 | if y_tvk > 0:
479 | self.Y_TK[t, k] += y_tvk
480 | self.Y_KV[k, v] += y_tvk
481 |
482 | cdef void _update_L_TKK(self):
483 | cdef:
484 | int t, k, k1, m, l_tk
485 | double norm, mu, zeta
486 | unsigned int l_tkk
487 |
488 | self.L_K[:] = 0
489 | self.L_KK[:] = 0
490 | self.L_TK[:] = 0
491 | self.L_TKK[:] = 0
492 |
493 | self.R_TK = self.tau * np.dot(self.Theta_TK, self.Pi_KK)
494 |
495 | with nogil:
496 | if self.steady == 1:
497 | zeta = self.zeta_T[self.T-1]
498 | for k in range(self.K):
499 | mu = zeta * self.Theta_TK[self.T-1, k]
500 | self.L_TK[self.T-1, k] = gsl_ran_poisson(self.rng, mu)
501 |
502 | for t in range(self.T-2, -1, -1):
503 | for k in range(self.K):
504 | m = self.Y_TK[t+1, k] + self.L_TK[t+1, k]
505 |
506 | if m == 0: # l_tk = 0 if m = 0
507 | continue
508 |
509 | l_tk = _sample_crt(self.rng, m, self.R_TK[t, k])
510 | # assert (l_tk >= 0) and (l_tk <= m)
511 |
512 | if l_tk == 0:
513 | continue
514 |
515 | for k1 in range(self.K):
516 | self.P_K[k1] = self.Theta_TK[t, k1] * self.Pi_KK[k1, k]
517 |
518 | gsl_ran_multinomial(self.rng,
519 | self.K,
520 | l_tk,
521 | &self.P_K[0],
522 | &self.N_K[0])
523 |
524 | for k1 in range(self.K):
525 | l_tkk = self.N_K[k1]
526 | if l_tkk > 0:
527 | self.L_K[k1] += l_tkk
528 | self.L_KK[k1, k] += l_tkk
529 | self.L_TK[t, k1] += l_tkk
530 | self.L_TKK[t, k1, k] = l_tkk
531 |
532 | cdef void _update_Theta_TK(self) nogil:
533 | cdef:
534 | int k, k1, t
535 | double shp, sca, tau, theta_tk
536 |
537 | tau = self.tau
538 |
539 | self.theta_ = 0
540 | self.Theta_T[:] = 0
541 | for t in range(self.T):
542 | sca = 1. / (tau + self.zeta_T[t] + self.delta_T[t])
543 |
544 | for k in range(self.K):
545 | if t == 0:
546 | shp = tau * self.nu_K[k]
547 | else:
548 | shp = 0
549 | for k1 in range(self.K):
550 | shp += self.Theta_TK[t-1, k1] * self.Pi_KK[k1, k]
551 | shp *= tau
552 | shp += self.L_TK[t, k] + self.Y_TK[t, k]
553 |
554 | theta_tk = _sample_gamma(self.rng, shp, sca)
555 | self.Theta_TK[t, k] = theta_tk
556 | self.Theta_T[t] += theta_tk
557 | self.theta_ += theta_tk
558 |
559 | cdef void _update_Phi_KV(self) nogil:
560 | cdef:
561 | int k, v
562 | double eps
563 | gsl_rng * rng
564 |
565 | eps = self.eps
566 | rng = self.rng
567 |
568 | for k in range(self.K):
569 | for v in range(self.V):
570 | self.shp_V[v] = eps + self.Y_KV[k, v]
571 | _sample_dirichlet(rng, self.shp_V, self.Phi_KV[k])
572 | # assert self.Phi_KV[k, 0] >= 0
573 |
574 | cdef void _update_delta_T(self) nogil:
575 | cdef:
576 | int t
577 | double shp, rte
578 |
579 | if self.stationary == 0:
580 | for t in range(self.T):
581 | shp = self.eps + self.Y_T[t]
582 | rte = self.eps + self.Theta_T[t]
583 | self.delta_T[t] = _sample_gamma(self.rng, shp, 1. / rte)
584 | else:
585 | shp = self.eps + self.y_
586 | rte = self.eps + self.theta_
587 | self.delta_T[:] = _sample_gamma(self.rng, shp, 1. / rte)
588 |
589 | self._update_zeta_T()
590 |
591 | cdef void _update_Pi_KK(self) nogil:
592 | cdef:
593 | int k, k2
594 | double nu_k
595 |
596 | for k in range(self.K):
597 | if self.shrink == 1:
598 | nu_k = self.nu_K[k]
599 | self.shp_KK[k, k] = nu_k * self.xi_K[k] + self.L_KK[k, k]
600 | for k2 in range(self.K):
601 | if k == k2:
602 | continue
603 | self.shp_KK[k, k2] = nu_k * self.nu_K[k2] + self.L_KK[k, k2]
604 | else:
605 | for k2 in range(self.K):
606 | self.shp_KK[k, k2] = self.eps + self.L_KK[k, k2]
607 |
608 | for k in range(self.K):
609 | _sample_dirichlet(self.rng, self.shp_KK[k], self.Pi_KK[k])
610 | # assert np.isfinite(self.Pi_KK[k]).all()
611 |
612 | cdef void _update_H_KK_and_lnq_K(self):
613 | cdef:
614 | int k, k2, l_k
615 | double nu_k, xi_k, tmp, r
616 |
617 | self.lnq_K[:] = 0
618 | for k in range(self.K):
619 | nu_k = self.nu_K[k]
620 | xi_k = self.xi_K[k]
621 | r = nu_k * xi_k
622 | self.H_KK[k, k] = _sample_crt(self.rng, self.L_KK[k, k], r)
623 | assert self.H_KK[k, k] >= 0
624 |
625 | for k2 in range(self.K):
626 | if k2 == k:
627 | continue
628 | r = nu_k * self.nu_K[k2]
629 | self.H_KK[k, k2] = _sample_crt(self.rng, self.L_KK[k, k2], r)
630 | assert self.H_KK[k, k2] >= 0
631 |
632 | l_k = self.L_K[k]
633 | if l_k > 0:
634 | tmp = (xi_k + self.nu_ - nu_k)
635 | self.lnq_K[k] = _sample_lnbeta(self.rng, nu_k * tmp, l_k)
636 | assert np.isfinite(self.lnq_K[k])
637 | assert self.lnq_K[k] <= 0
638 |
639 | cdef void _update_nu_K(self):
640 | cdef:
641 | int k, k2, m_k, l_0k
642 | double tau, zeta, gam_k, nu_k, shp, rte
643 | # list indices
644 |
645 | tau = self.tau
646 | zeta = tau * log1p((self.delta_T[0] + self.zeta_T[0]) / tau)
647 | gam_k = self.gam / self.K
648 |
649 | # indices = range(self.K)
650 | # np.random.shuffle(indices)
651 |
652 | for k in range(self.K):
653 | nu_k = self.nu_K[k]
654 | m_k = self.Y_TK[0, k] + self.L_TK[0, k]
655 | l_0k = _sample_crt(self.rng, m_k, tau * nu_k)
656 | assert l_0k >= 0
657 |
658 | shp = gam_k + l_0k
659 | rte = self.beta + zeta
660 |
661 | if self.shrink == 1:
662 | shp += self.H_KK[k, k]
663 | rte += (self.xi_K[k] + self.nu_ - nu_k) * (-self.lnq_K[k])
664 | for k2 in range(self.K):
665 | if k2 == k:
666 | continue
667 | shp += self.H_KK[k, k2] + self.H_KK[k2, k]
668 | rte += self.nu_K[k2] * (-self.lnq_K[k2])
669 |
670 | assert np.isfinite(shp) and np.isfinite(rte)
671 | assert shp >= 0 and rte >= 0
672 |
673 | self.nu_ -= nu_k
674 | nu_k = _sample_gamma(self.rng, shp, 1. / rte)
675 | self.nu_K[k] = nu_k
676 | self.nu_ += nu_k
677 |
678 | cdef void _update_xi_K(self) nogil:
679 |
680 | cdef:
681 | int k
682 | double shp, rte
683 |
684 | shp = rte = self.eps
685 | for k in range(self.K):
686 | shp += self.H_KK[k, k]
687 | rte -= self.nu_K[k] * self.lnq_K[k]
688 | # assert np.isfinite(shp) and np.isfinite(rte)
689 |
690 | self.xi_K[:] = _sample_gamma(self.rng, shp, 1. / rte)
691 |
692 | cdef void _update_beta(self) nogil:
693 | cdef:
694 | double shp, rte
695 |
696 | shp = self.eps + self.gam
697 | rte = self.eps + self.nu_
698 | self.beta = _sample_gamma(self.rng, shp, 1. / rte)
699 |
--------------------------------------------------------------------------------
/src/pp_plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import seaborn as sns
4 |
5 |
# RGBA colors used by the seaborn theme below.
gray = (102/255.0, 102/255.0, 102/255.0, 1.0)
light_gray = (238/255.0, 238/255.0, 238/255.0, 1.0)

# Apply the plot theme one rc override at a time.
# NOTE(review): assumes sns.set_style merges (rather than resets) style
# parameters across calls -- confirm for the installed seaborn version.
sns.set_style({'font.family': 'Abel'})
sns.set_style({'axes.facecolor': light_gray})
sns.set_style({'xtick.color': gray})
sns.set_style({'text.color': gray})
sns.set_style({'ytick.color': gray})
sns.set_style({'axes.grid': False})
15 |
16 |
17 | def _cdf(data):
18 | """
19 | Returns the empirical CDF (a function) for the specified data.
20 |
21 | Arguments:
22 |
23 | data -- data from which to compute the CDF
24 | """
25 |
26 | tmp = np.empty_like(data)
27 | tmp[:] = data
28 | tmp.sort()
29 |
30 | def f(x):
31 | return np.searchsorted(tmp, x, 'right') / float(len(tmp))
32 |
33 | return f
34 |
35 |
def pp_plot(a, b, t=None):
    """
    Generates a P-P plot comparing the empirical CDFs of a and b.

    Arguments:

    a -- sample (array-like), or a dict mapping names to samples
    b -- sample (array-like), or a dict with the same keys as a
    t -- optional title (the key name is appended in the dict case)

    In the dict case one subplot is drawn per key in a 2x2 grid, so at
    most four keys are supported.
    """

    if isinstance(a, dict):
        # Compare keys as sets: dict key order is arbitrary, so an
        # order-sensitive comparison of key lists can fail spuriously.
        assert isinstance(b, dict) and set(a.keys()) == set(b.keys())
        for n, (k, v) in enumerate(a.items()):
            plt.subplot(221 + n)
            x = np.sort(np.asarray(v))
            if len(x) > 10000:
                # Thin the evaluation grid to ~5000 points; floor
                # division keeps step an int under true division too.
                step = len(x) // 5000
                x = x[::step]
            plt.plot(_cdf(v)(x), _cdf(b[k])(x), lw=3, alpha=0.7)
            plt.plot([0, 1], [0, 1], ':', c='k', lw=4, alpha=0.7)
            if t is not None:
                plt.title(t + ' (' + k + ')')
        # Lay out and show once, after all subplots are drawn; calling
        # plt.show() inside the loop would display (and block on) each
        # subplot separately, splitting the grid across figures.
        plt.tight_layout()
        plt.show()
    else:
        x = np.sort(np.asarray(a))
        if len(x) > 10000:
            step = len(x) // 5000
            x = x[::step]
        plt.plot(_cdf(a)(x), _cdf(b)(x), lw=3, alpha=0.7)
        plt.plot([0, 1], [0, 1], ':', c='k', lw=4, alpha=0.7)
        if t is not None:
            plt.title(t)
        plt.tight_layout()
        plt.show()
66 |
67 |
def test(num_samples=100000):
    """
    Smoke test: P-P plot of two samples from the same normal, which
    should hug the diagonal.
    """

    sample_a = np.random.normal(20.0, 5.0, num_samples)
    sample_b = np.random.normal(20.0, 5.0, num_samples)
    pp_plot(sample_a, sample_b)
76 |
77 |
# Run the smoke test when executed as a script.
if __name__ == '__main__':
    test()
80 |
--------------------------------------------------------------------------------
/src/run_pgds.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import cPickle as pickle
4 | import numpy as np
5 | import numpy.random as rn
6 |
7 | from argparse import ArgumentParser
8 | from path import path
9 | from time import sleep
10 |
11 | from pgds import PGDS
12 |
13 |
def get_train_forecast_split(data, mask):
    """
    Split a (T, V) matrix into training rows and held-out forecast rows.

    The forecast portion is the maximal suffix of rows that are entirely
    masked; if the final row is not fully masked there is nothing to
    forecast and (data, None) is returned.
    """
    assert mask.shape == data.shape
    if not mask[-1].all():
        return data, None

    # Grow the candidate suffix until it includes an unmasked entry;
    # S is then the number of fully-masked trailing rows.
    for S in range(data.shape[0]):
        if not mask[-(S + 1):].all():
            break
    assert S >= 1
    return data[:-S], data[-S:]
26 |
27 |
def get_chain_num(out_dir):
    """
    Allocate the next chain number by bumping num_chains.txt in out_dir.

    The random stagger (0-2s) reduces the chance that two concurrently
    launched jobs read/write the counter file at the same time; this is
    best-effort, not a real lock.
    """
    sleep(rn.random() * 2)

    counter_path = out_dir.joinpath('num_chains.txt')

    if counter_path.exists():
        chain_num = int(np.loadtxt(counter_path)) + 1
    else:
        chain_num = 1
    np.savetxt(counter_path, np.array([chain_num]))

    return chain_num
39 |
40 |
def rmse(truth, pred):
    """Root-mean-squared error between truth and pred."""
    squared_err = (truth - pred) ** 2
    return np.sqrt(squared_err.mean())
43 |
44 |
def mae(truth, pred):
    """Mean absolute error between truth and pred."""
    abs_err = np.abs(truth - pred)
    return abs_err.mean()
47 |
48 |
def mre(truth, pred):
    """Mean relative error; the denominator is truth + 1 so zero counts
    do not divide by zero."""
    rel_err = np.abs(truth - pred) / (truth + 1)
    return rel_err.mean()
51 |
52 |
def save_forecast_eval(data_SV, model, eval_path, pred_path=None):
    """
    Score the model's S-step-ahead forecast and append one TSV row
    (iteration, MAE, MRE, RMSE) to eval_path, writing the header only
    when the file does not exist yet.  Optionally saves the forecasted
    rates to pred_path (npz, key 'pred_SV').
    """
    header = '' if eval_path.exists() else 'ITN\tMAE\tMRE\tRMSE\n'

    S, V = data_SV.shape
    pred_SV = model.forecast(n_timesteps=S)

    itn = model.get_total_itns()

    row = '%d\t%f\t%f\t%f\n' % (itn,
                                mae(data_SV, pred_SV),
                                mre(data_SV, pred_SV),
                                rmse(data_SV, pred_SV))
    with open(eval_path, 'a+') as f:
        f.write(header + row)

    if pred_path is not None:
        np.savez(pred_path, pred_SV=pred_SV)
72 |
73 |
def save_smoothing_eval(masked_data, model, eval_path, pred_path=None):
    """
    Score reconstruction of the held-out (masked) entries and append one
    TSV row (iteration, MAE, MRE, RMSE) to eval_path, writing the header
    only when the file does not exist yet.  Optionally saves the
    predictions to pred_path (npz, key 'pred_N').
    """
    header = '' if eval_path.exists() else 'ITN\tMAE\tMRE\tRMSE\n'

    mask_TV = masked_data.mask
    assert mask_TV.any()

    data_N = masked_data.data[mask_TV]
    pred_N = model.reconstruct(subs=np.where(mask_TV))

    itn = model.get_total_itns()

    row = '%d\t%f\t%f\t%f\n' % (itn,
                                mae(data_N, pred_N),
                                mre(data_N, pred_N),
                                rmse(data_N, pred_N))
    with open(eval_path, 'a+') as f:
        f.write(header + row)

    if pred_path is not None:
        np.savez(pred_path, pred_N=pred_N)
96 |
97 |
def main():
    """
    Command-line driver: load a (data, mask) npz, fit a PGDS model with
    periodic checkpointing, and write per-chain state and evaluation
    files to the output directory.
    """
    p = ArgumentParser()
    p.add_argument('-d', '--data', type=path, required=True)
    p.add_argument('-o', '--out', type=path, required=True)
    p.add_argument('-k', '--n_components', type=int, default=100)
    p.add_argument('--eps', type=float, default=0.1)
    p.add_argument('--tau', type=float, default=1.0)
    p.add_argument('--gam', type=float, default=100.0)
    p.add_argument('--binary', action="store_true", default=False)
    p.add_argument('--stationary', action="store_true", default=False)
    p.add_argument('--steady', action="store_true", default=False)
    p.add_argument('-s', '--seed', type=int, default=None)
    p.add_argument('-v', '--verbose', action="store_true", default=False)
    p.add_argument('-n', '--num_itns', type=int, default=6000)
    p.add_argument('--save_after', type=int, default=4000)
    p.add_argument('--save_every', type=int, default=100)
    p.add_argument('--eval_after', type=int, default=4000)
    p.add_argument('--eval_every', type=int, default=100)
    args = p.parse_args()

    # Trailing fully-masked rows (if any) are held out for forecasting.
    data_dict = np.load(args.data)
    data = data_dict['data']
    mask = data_dict['mask']
    data_TV, data_SV = get_train_forecast_split(data, mask)
    T, V = data_TV.shape
    if data_SV is not None:
        S = data_SV.shape[0]
        mask_TV = mask[:T]
    else:
        S = 0
        mask_TV = mask
    masked_data = np.ma.array(data_TV, mask=mask_TV)

    args.out.makedirs_p()

    model = PGDS(T=T,
                 V=V,
                 K=args.n_components,
                 eps=args.eps,
                 gam=args.gam,
                 tau=args.tau,
                 stationary=int(args.stationary),
                 steady=int(args.steady),
                 binary=int(args.binary),
                 seed=args.seed)

    # Per-variable burn-in iteration counts passed through to fit.
    burnin = {'Y_KV': 0,
              'Y_TK': 0,
              'L_TKK': 0,
              'H_KK': 50,
              'lnq_K': 50,
              'Phi_KV': 0,
              'delta_T': 0,
              'Pi_KK': 40,
              'nu_K': 50,
              'Theta_TK': 20,
              'beta': 60,
              'xi_K': 70}

    num_itns = args.num_itns
    save_after = args.save_after
    save_every = args.save_every
    eval_after = args.eval_after
    eval_every = args.eval_every

    itns = np.arange(num_itns + 1)  # include iteration 0

    itns_to_eval = np.zeros(num_itns + 1, dtype=bool)  # include iteration 0
    if eval_every is not None:
        itns_to_eval = itns % eval_every == 0
        itns_to_eval[:eval_after] = False  # dont save before eval_after
    itns_to_eval[0] = True   # always evaluate the first
    itns_to_eval[-1] = True  # always evaluate the last

    itns_to_save = np.zeros(num_itns + 1, dtype=bool)  # include iteration 0
    if save_every is not None:
        itns_to_save = itns % save_every == 0
        itns_to_save[:save_after] = False  # dont save before save_after
    itns_to_save[0] = True   # except the first, always save the first
    itns_to_save[-1] = True  # always save the last

    # Elementwise logical OR of the two boolean arrays.
    itns_to_checkpoint = itns_to_eval + itns_to_save
    itns_to_checkpoint[0] = True  # this ensures num_itns_until_chkpt works
    checkpoint_itns = np.where(itns_to_checkpoint)[0]

    num_checkpoints = checkpoint_itns.size
    assert num_checkpoints >= 2  # checkpoint_itns will at least be [0,num_itns]
    assert checkpoint_itns[-1] == num_itns

    chain = get_chain_num(args.out)
    with open(args.out.joinpath('%d_params.p' % chain), 'wb') as params_file:
        pickle.dump(model.get_params(), params_file)

    # Run the sampler in segments, stopping at each checkpoint iteration
    # to save state and/or write evaluation metrics.
    for c, itn in enumerate(checkpoint_itns):
        if c == 0:
            initialize = True
            num_itns_until_chkpt = 0
        else:
            initialize = False
            num_itns_until_chkpt = checkpoint_itns[c] - checkpoint_itns[c-1]

        model.fit(masked_data,
                  num_itns=num_itns_until_chkpt,
                  verbose=args.verbose,
                  initialize=initialize,
                  burnin=burnin)

        if itns_to_save[itn]:
            state_name = '%d_state_%d.npz' % (chain, itn)
            np.savez(args.out.joinpath(state_name), **dict(model.get_state()))

        if itns_to_eval[itn]:
            eval_path = args.out.joinpath('%d_smoothing_eval.txt' % chain)
            pred_path = args.out.joinpath('%d_smoothed_%d.npz' % (chain, itn))
            save_smoothing_eval(masked_data, model, eval_path, pred_path)

        # NOTE(review): forecast evaluation runs at every checkpoint,
        # not only at eval iterations -- confirm this is intentional.
        if S == 0:
            continue
        eval_path = args.out.joinpath('%d_forecast_eval.txt' % chain)
        pred_path = args.out.joinpath('%d_forecast_%d.npz' % (chain, itn))
        save_forecast_eval(data_SV, model, eval_path, pred_path)
219 |
# Script entry point.
if __name__ == '__main__':
    main()
222 |
--------------------------------------------------------------------------------
/src/sample.pxd:
--------------------------------------------------------------------------------
1 | #!python
2 | #cython: boundscheck=False
3 | #cython: cdivision=True
4 | #cython: infertypes=True
5 | #cython: initializedcheck=False
6 | #cython: nonecheck=False
7 | #cython: wraparound=False
8 | #distutils: language = c
9 | #distutils: extra_link_args = ['-lgsl', '-lgslcblas']
10 | #distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration
11 |
12 | cimport numpy as np
13 | from libc.math cimport log, log1p, exp, M_E, tgamma, sqrt
14 |
15 | cdef extern from "gsl/gsl_rng.h" nogil:
16 | ctypedef struct gsl_rng_type:
17 | pass
18 | ctypedef struct gsl_rng:
19 | pass
20 | gsl_rng_type *gsl_rng_mt19937
21 | gsl_rng *gsl_rng_alloc(gsl_rng_type * T)
22 | void gsl_rng_set(gsl_rng * r, unsigned long int)
23 | void gsl_rng_free(gsl_rng * r)
24 |
25 | cdef extern from "gsl/gsl_randist.h" nogil:
26 | double gsl_rng_uniform(gsl_rng * r)
27 | unsigned int gsl_ran_poisson(gsl_rng * r, double mu)
28 | double gsl_ran_gamma(gsl_rng * r, double a, double b)
29 | unsigned int gsl_ran_logarithmic (const gsl_rng * r, double p)
30 | void gsl_ran_multinomial(gsl_rng * r,
31 | size_t K,
32 | unsigned int N,
33 | const double p[],
34 | unsigned int n[])
35 | double gsl_ran_ugaussian (const gsl_rng * r)
36 |
37 | DEF EPS = 1e-300
38 |
cdef inline double _sample_gamma(gsl_rng * rng, double a, double b) nogil:
    """
    Draw a Gamma(shape=a, scale=b) random variate.

    For shapes above 0.75 this defers to gsl_ran_gamma; for smaller
    shapes it exponentiates a log-space draw from
    _sample_lngamma_small_shape, which is much more accurate when a is
    tiny.  Both the shape parameter and the returned value are clipped
    from below at EPS (so the result is always >= EPS).

    Arguments:
    rng -- Pointer to a GSL random number generator object
    a -- Shape parameter (a > 0)
    b -- Scale parameter (b > 0)
    """
    cdef double sample

    if a < EPS:
        a = EPS

    if a <= 0.75:
        sample = exp(_sample_lngamma_small_shape(rng, a, b))
    else:
        sample = gsl_ran_gamma(rng, a, b)

    if sample < EPS:
        sample = EPS
    return sample
67 |
cdef inline double _sample_lngamma(gsl_rng * rng, double a, double b) nogil:
    """
    Draw the natural log of a Gamma(shape=a, scale=b) random variate.

    Small shapes (a <= 0.75) are handled entirely in log space by
    _sample_lngamma_small_shape; larger shapes take the log of a
    gsl_ran_gamma draw, clipped from below at log(EPS).  The shape
    parameter itself is clipped at EPS.

    Arguments:
    rng -- Pointer to a GSL random number generator object
    a -- Shape parameter (a > 0)
    b -- Scale parameter (b > 0)
    """
    cdef double sample

    if a < EPS:
        a = EPS

    if a <= 0.75:
        return _sample_lngamma_small_shape(rng, a, b)

    sample = gsl_ran_gamma(rng, a, b)
    if sample > EPS:
        return log(sample)
    return log(EPS)
93 |
cdef inline double _sample_lngamma_small_shape(gsl_rng * rng,
                                               double a,
                                               double b) nogil:
    """
    Implements the rejection algorithm of Liu, Martin and Syring for
    simulating Gamma random variates with small shape parameters
    (a < 1), returning the log of the variate.

    Do not call this with a > 1 (the expected number of iterations is large).

    Arguments:
    rng -- Pointer to a GSL random number generator object
    a -- Shape parameter (0 < a < 1)
    b -- Scale parameter (b > 0)

    References:
    [1] C. Liu, R. Martin & N. Syring (2013).
        Simulating from a gamma distribution with small shape parameter
    """
    cdef:
        # Fixed: the original declared an unused `w` and left `lnw` to
        # type inference; `lnw` must be a C double in this nogil function.
        double lam, lnw, lnr, lnu, z, lnh, lneta

    lam = 1. / a - 1.
    # w = a / (e * (1 - a)), kept in log space to avoid underflow.
    lnw = log(a) - log1p(-a) - 1
    # r = 1 / (1 + w), also in log space.
    lnr = -log1p(exp(lnw))

    while 1:
        # Propose z from the two-component mixture: an exponential tail
        # (z >= 0) with probability r, else a reflected exponential.
        lnu = log(gsl_rng_uniform(rng))
        if lnu <= lnr:
            z = lnr - lnu
        else:
            z = log(gsl_rng_uniform(rng)) / lam

        # Accept z with probability h(z) / eta(z), all in log space.
        lnh = -z - exp(-z / a)
        if z >= 0:
            lneta = -z
        else:
            lneta = lnw + log(lam) + lam * z

        if (lnh - lneta) > log(gsl_rng_uniform(rng)):
            # X = b * exp(-z / a) is Gamma(a, b); return its log.
            return log(b) - z / a
134 |
135 |
cdef inline double _sample_lnbeta(gsl_rng * rng, double a, double b) nogil:
    """
    Sample a log-Beta variate from two log-Gamma draws:

        log B = log G1 - logsumexp(log G1, log G2)

    Arguments:
    rng -- Pointer to a GSL random number generator object
    a -- First shape parameter (a > 0)
    b -- Second shape parameter (b > 0)
    """
    cdef:
        double x, y, m

    x = _sample_lngamma(rng, a, 1.)
    y = _sample_lngamma(rng, b, 1.)

    # Stabilized logsumexp of the two log-Gamma draws.
    m = x
    if y > m:
        m = y

    return x - (m + log(exp(x - m) + exp(y - m)))
155 |
cdef inline double _sample_beta(gsl_rng * rng, double a, double b) nogil:
    """
    Sample a Beta variate by exponentiating a log-Beta draw.

    Arguments:
    rng -- Pointer to a GSL random number generator object
    a -- First shape parameter (a > 0)
    b -- Second shape parameter (b > 0)
    """
    cdef double lnbeta = _sample_lnbeta(rng, a, b)
    return exp(lnbeta)
166 |
cdef inline void _sample_lndirichlet(gsl_rng * rng,
                                     double[::1] alpha,
                                     double[::1] out) nogil:
    """
    Sample a K-dimensional log-Dirichlet: draw K log-Gamma variates and
    normalize them in log-space via logsumexp.

    Arguments:
    rng -- Pointer to a GSL random number generator object
    alpha -- Concentration parameters (alpha[k] > 0 for k = 1...K)
    out -- Output array (same size as alpha)
    """
    cdef:
        int n, i
        double hi, norm

    n = alpha.shape[0]
    if out.shape[0] != n:
        out[0] = -1  # signal a size mismatch
        return

    # Unnormalized log-Gamma draws.
    for i in range(n):
        out[i] = _sample_lngamma(rng, alpha[i], 1.)

    # Largest component, used to stabilize the exponentials below.
    hi = out[0]
    for i in range(1, n):
        if out[i] > hi:
            hi = out[i]

    # logsumexp normalizer.
    norm = 0
    for i in range(n):
        norm += exp(out[i] - hi)
    norm = hi + log(norm)

    for i in range(n):
        out[i] = out[i] - norm
200 |
cdef inline void _sample_dirichlet(gsl_rng * rng,
                                   double[::1] alpha,
                                   double[::1] out) nogil:
    """
    Sample a K-dimensional Dirichlet by exponentiating a log-Dirichlet draw.

    Arguments:
    rng -- Pointer to a GSL random number generator object
    alpha -- Concentration parameters (alpha[k] > 0 for k = 1...K)
    out -- Output array (same size as alpha)
    """
    cdef:
        int n, i

    n = alpha.shape[0]
    if out.shape[0] != n:
        out[0] = -1  # signal a size mismatch
        return

    _sample_lndirichlet(rng, alpha, out)

    for i in range(n):
        out[i] = exp(out[i])
224 |
cdef inline int _sample_categorical(gsl_rng * rng, double[::1] dist) nogil:
    """
    Draw an index from an (unnormalized) discrete distribution using the
    inverse-CDF method: scale a uniform by the total mass, then subtract
    weights until it is crossed.

    Returns -1 if the walk falls off the end (e.g. all weights zero).

    TODO: Use searchsorted to reduce complexity from O(K) to O(logK).
    This requires creating a CDF array which requires GIL, if using numpy.

    Arguments:
    rng -- Pointer to a GSL random number generator object
    dist -- (unnormalized) distribution
    """
    cdef:
        int i, n
        double remaining

    n = dist.shape[0]

    # Total (unnormalized) mass.
    remaining = 0.0
    for i in range(n):
        remaining += dist[i]

    # A uniform point in [0, total); find which weight it lands in.
    remaining *= gsl_rng_uniform(rng)
    for i in range(n):
        remaining -= dist[i]
        if remaining <= 0.0:
            return i

    return -1
256 |
cdef inline int _searchsorted(double val, double[::1] arr) nogil:
    """
    Binary search: return the index of the first element of a sorted
    (ascending) array that is >= the given value; returns the last index
    when every element is smaller.

    Arguments:
    val -- Given value to search for
    arr -- Sorted (ascending order) array
    """
    cdef:
        int imin, imax, imid

    imin = 0
    imax = arr.shape[0] - 1
    while (imin < imax):
        # Overflow-safe midpoint: `(imin + imax) / 2` can exceed INT_MAX
        # for very large arrays; this form cannot.
        imid = imin + (imax - imin) / 2
        if arr[imid] < val:
            imin = imid + 1
        else:
            imax = imid
    return imin
277 |
cdef inline int _sample_crt(gsl_rng * rng, int m, double r) nogil:
    """
    Sample a Chinese Restaurant Table (CRT) random variable [1].

    l ~ CRT(m, r) is distributed as a sum of independent Bernoullis:

        l = \sum_{n=1}^m Bernoulli(r/(r+n-1))

    where m >= 0 is integer and r >= 0 is real.  Returns -1 on invalid
    arguments.

    Arguments:
    rng -- Pointer to a GSL random number generator object
    m -- First parameter of the CRT (m >= 0)
    r -- Second parameter of the CRT (r >= 0)

    References:
    [1] M. Zhou & L. Carin (2012). Negative Binomial Count and Mixture Modeling.
    """
    cdef:
        int total, i
        double p

    if m < 0 or r < 0:
        return -1  # represents an error

    # Degenerate cases: no trials; a single trial (its Bernoulli has p = 1);
    # or vanishing r, where only that first trial can succeed.
    if m == 0:
        return 0
    if m == 1 or r <= 1e-50:
        return 1

    total = 0
    for i in range(m):
        # i runs 0..m-1, so this is r / (r + n - 1) for n = 1..m.
        p = r / (r + i)
        if p > gsl_rng_uniform(rng):
            total += 1
    return total
319 |
cdef inline int _sample_sumcrt(gsl_rng * rng, int[::1] M, double[::1] R) nogil:
    """
    Sample the sum of K independent CRT random variables.

    l ~ \sum_{k=1}^K CRT(m_k, r_k)

    Returns -1 if M and R differ in length or any component draw signals
    an error (invalid m_k or r_k).

    Arguments:
    rng -- Pointer to a GSL random number generator object
    M -- Array of first parameters
    R -- Array of second parameters (same size as M)
    """
    cdef:
        int l, lk, K, k

    K = M.shape[0]
    # Guard the parallel access below (boundscheck is disabled file-wide).
    if R.shape[0] != K:
        return -1

    l = 0
    for k in range(K):
        lk = _sample_crt(rng, M[k], R[k])
        if lk == -1:
            return -1
        l += lk

    # BUG FIX: the accumulated total was previously never returned -- the
    # function fell off the end, so successful calls always yielded 0.
    return l
343 |
cdef inline int _sample_sumlog(gsl_rng * rng, int n, double p) nogil:
    """
    Sample a SumLog random variable, the sum of n iid Logarithmic rvs:

        y ~ \sum_{i=1}^n Logarithmic(p)

    Returns -1 on invalid arguments (p outside (0, 1) or n < 0); n == 0
    yields 0 since the loop body never runs.

    Arguments:
    rng -- Pointer to a GSL random number generator object
    n -- Parameter for number of iid Logarithmic rvs
    p -- Probability parameter of the Logarithmic distribution

    TODO: For very large p (e.g., =0.99999) this sometimes fails and returns
    negative integers. Figure out why gsl_ran_logarithmic is failing.
    """
    cdef:
        int total, i

    if (p <= 0) or (p >= 1) or (n < 0):
        return -1  # this represents an error

    total = 0
    for i in range(n):
        total += gsl_ran_logarithmic(rng, p)
    return total
371 |
372 |
cdef inline int _sample_trunc_poisson(gsl_rng * rng, double mu) nogil:
    """
    Sample a zero-truncated Poisson random variable as described by
    Zhou (2015) [1].

    Arguments:
    rng -- Pointer to a GSL random number generator object
    mu -- Poisson rate parameter

    References:
    [1] Zhou, M. (2015). Infinite Edge Partition Models for Overlapping
        Community Detection and Link Prediction.
    """
    cdef:
        unsigned int draw
        double u

    if mu >= 1:
        # Large rate: plain rejection -- redraw until the Poisson is nonzero.
        while 1:
            draw = gsl_ran_poisson(rng, mu)
            if draw > 0:
                return draw

    # Small rate: shift a Poisson draw up by one and thin with a uniform.
    while 1:
        draw = gsl_ran_poisson(rng, mu) + 1
        u = gsl_rng_uniform(rng)
        if draw < 1. / u:
            return draw
400 |
401 |
cdef inline void _sample_multinomial(gsl_rng * rng,
                                     unsigned int N,
                                     double[::1] p,
                                     unsigned int[::1] out) nogil:
    """
    Thin wrapper around gsl_ran_multinomial: distribute N trials over the
    len(p) categories with weights p, writing the counts into out.
    """
    cdef size_t n_cats = p.shape[0]
    gsl_ran_multinomial(rng, n_cats, N, &p[0], &out[0])
414 |
415 |
cdef class Sampler:
    """
    Wrapper for a gsl_rng object that exposes all sampling methods to Python.

    Useful for testing or writing pure Python programs.
    """
    cdef:
        # Owned GSL generator; allocated on construction and freed in
        # __dealloc__ (see sample.pyx).
        gsl_rng *rng

    # Python-callable entry points; each delegates to the matching
    # module-level _sample_* / _searchsorted helper declared above.
    cpdef double gamma(self, double a, double b)
    cpdef double lngamma(self, double a, double b)
    cpdef double beta(self, double a, double b)
    cpdef double lnbeta(self, double a, double b)
    cpdef void dirichlet(self, double[::1] alpha, double[::1] out)
    cpdef void lndirichlet(self, double[::1] alpha, double[::1] out)
    cpdef int categorical(self, double[::1] dist)
    cpdef int searchsorted(self, double val, double[::1] arr)
    cpdef int crt(self, int m, double r)
    cpdef int sumcrt(self, int[::1] M, double[::1] R)
    cpdef int sumlog(self, int n, double p)
    cpdef int trunc_poisson(self, double mu)
    cpdef void multinomial(self,
                           unsigned int N,
                           double[::1] p,
                           unsigned int[::1] out)
441 |
--------------------------------------------------------------------------------
/src/sample.pyx:
--------------------------------------------------------------------------------
1 | #!python
2 | #cython: boundscheck=False
3 | #cython: cdivision=True
4 | #cython: infertypes=True
5 | #cython: initializedcheck=False
6 | #cython: nonecheck=False
7 | #cython: wraparound=False
8 | #distutils: language = c
9 | #distutils: extra_link_args = ['-lgsl', '-lgslcblas']
10 | #distutils: extra_compile_args = -Wno-unused-function -Wno-unneeded-internal-declaration
11 |
12 | import sys
13 | from numpy.random import randint
14 |
cdef class Sampler:
    """
    Wrapper for a gsl_rng object that exposes all sampling methods to Python.

    Useful for testing or writing pure Python programs.
    """
    def __cinit__(self, object seed=None):
        # Allocate in __cinit__ (not __init__) so self.rng is guaranteed
        # valid by the time __dealloc__ runs, even if __init__ is skipped.
        self.rng = gsl_rng_alloc(gsl_rng_mt19937)

        if seed is None:
            # sys.maxsize works on both Python 2 and 3 (sys.maxint was
            # removed in Python 3); GSL takes a 32-bit seed, hence the mask.
            seed = randint(0, sys.maxsize) & 0xFFFFFFFF
        gsl_rng_set(self.rng, seed)

    def __dealloc__(self):
        """
        Free the GSL random number generator.
        """
        gsl_rng_free(self.rng)

    cpdef double gamma(self, double a, double b):
        return _sample_gamma(self.rng, a, b)

    cpdef double lngamma(self, double a, double b):
        return _sample_lngamma(self.rng, a, b)

    cpdef double beta(self, double a, double b):
        return _sample_beta(self.rng, a, b)

    cpdef double lnbeta(self, double a, double b):
        return _sample_lnbeta(self.rng, a, b)

    cpdef void dirichlet(self, double[::1] alpha, double[::1] out):
        _sample_dirichlet(self.rng, alpha, out)

    cpdef void lndirichlet(self, double[::1] alpha, double[::1] out):
        _sample_lndirichlet(self.rng, alpha, out)

    cpdef int categorical(self, double[::1] dist):
        return _sample_categorical(self.rng, dist)

    cpdef int searchsorted(self, double val, double[::1] arr):
        return _searchsorted(val, arr)

    cpdef int crt(self, int m, double r):
        return _sample_crt(self.rng, m, r)

    cpdef int sumcrt(self, int[::1] M, double[::1] R):
        return _sample_sumcrt(self.rng, M, R)

    cpdef int sumlog(self, int n, double p):
        return _sample_sumlog(self.rng, n, p)

    cpdef int trunc_poisson(self, double mu):
        return _sample_trunc_poisson(self.rng, mu)

    cpdef void multinomial(self, unsigned int N, double[::1] p, unsigned int[::1] out):
        _sample_multinomial(self.rng, N, p, out)
74 |
75 |
--------------------------------------------------------------------------------
/src/setup.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# BUG FIX: the shebang previously read "pythonv", which does not exist.
"""Build script: cythonizes every .pyx module under the source tree."""
import numpy as np
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
from Cython.Build import cythonize


setup(
    cmdclass={'build_ext': build_ext},
    # ../../fatwalrus presumably supplies shared .pxd files cimported by
    # the modules here -- confirm against the sibling checkout layout.
    include_dirs=[np.get_include(), '../../fatwalrus'],
    ext_modules=cythonize('**/*.pyx', include_path=['../../fatwalrus'])
)
14 |
--------------------------------------------------------------------------------
/src/test_pgds.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numpy.random as rn
3 | import scipy.stats as st
4 |
5 | from pgds import PGDS
6 | from IPython import embed
7 |
if __name__ == '__main__':
    # Small synthetic dimensions for a quick sanity run of the sampler.
    V = 3          # vocabulary size
    T = 4          # number of time steps
    K = 5          # number of latent components
    eps = 0.75
    gam = 30.
    tau = 0.75
    shrink = 0
    stationary = 1
    steady = 1
    binary = 0

    seed = rn.randint(10000)
    # BUG FIX: `print seed` is Python-2-only statement syntax (SyntaxError
    # under Python 3); the call form works on both.
    print(seed)

    model = PGDS(T=T, V=V, K=K, eps=eps, gam=gam, tau=tau,
                 shrink=shrink, stationary=stationary, steady=steady,
                 binary=binary, seed=seed)

    # Per-variable burn-in counts; np.inf presumably keeps xi_K fixed for
    # the whole run -- confirm against the schein/geweke implementation.
    burnin = {'Y_KV': 0,
              'Y_TK': 0,
              'L_TKK': 0,
              'H_KK': 0,
              'lnq_K': 0,
              'Theta_TK': 0,
              'Phi_KV': 0,
              'delta_T': 0,
              'Pi_KK': 0,
              'nu_K': 0,
              'beta': 0,
              'xi_K': np.inf}

    # Summary statistics of the column entropies, tracked per variable.
    entropy_funcs = {'Entropy min': lambda x: np.min(st.entropy(x)),
                     'Entropy max': lambda x: np.max(st.entropy(x)),
                     'Entropy mean': lambda x: np.mean(st.entropy(x)),
                     'Entropy var': lambda x: np.var(st.entropy(x))}

    var_funcs = {'Pi_KK': entropy_funcs,
                 'Phi_KV': entropy_funcs}

    model.schein(30000, var_funcs=var_funcs, burnin=burnin)
    # model.geweke(200000, var_funcs=var_funcs, burnin=burnin)
50 |
--------------------------------------------------------------------------------