├── .gitignore ├── README.md ├── dat └── icews │ └── undirected │ └── 2003-D │ └── data.npz └── src ├── Makefile ├── icews_example.ipynb ├── impute.py ├── lambertw.pxd ├── lambertw.pyx ├── mcmc_model.pxd ├── mcmc_model.pyx ├── pgds.pyx ├── pp_plot.py ├── run_pgds.py ├── sample.pxd ├── sample.pyx ├── setup.py └── test_pgds.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | 81 | # virtualenv 82 | venv/ 83 | ENV/ 84 | 85 | # Spyder project settings 86 | .spyderproject 87 | 88 | # Rope project settings 89 | .ropeproject 90 | 91 | ## Custom 92 | src/lambertw.c 93 | src/mcmc_model.c 94 | src/pgds.c 95 | src/sample.c 96 | 97 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Poisson-Gamma Dynamical Systems 2 | Source code for the paper: [Poisson-Gamma Dynamical Systems](http://people.cs.umass.edu/~aschein/ScheinZhouWallach2016_paper.pdf) by Aaron Schein, Mingyuan Zhou, and Hanna Wallach, presented at NIPS 2016. 3 | 4 | The MIT License (MIT) 5 | 6 | Copyright (c) 2016 Aaron Schein 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | 26 | ## What's included: 27 | * [pgds.pyx](https://github.com/aschein/pgds/blob/master/src/pgds.pyx): The main code file. Implements Gibbs sampling inference for PGDS. 28 | * [mcmc_model.pyx](https://github.com/aschein/pgds/blob/master/src/mcmc_model.pyx): Implements Cython interface for MCMC models. Inherited by pgds.pyx.
29 | * [sample.pyx](https://github.com/aschein/pgds/blob/master/src/sample.pyx): Implements fast Cython methods for sampling various distributions. 30 | * [lambertw.pyx](https://github.com/aschein/pgds/blob/master/src/lambertw.pyx): Code for computing the Lambert-W function. 31 | * [Makefile](https://github.com/aschein/pgds/blob/master/src/Makefile): Makefile (cd into this directory and type 'make' to compile). 32 | * [icews_example.ipynb](https://github.com/aschein/pgds/blob/master/src/icews_example.ipynb): Jupyter notebook with examples of how to use the code to run PGDS on ICEWS data for exploratory and predictive analyses. 33 | 34 | ## Dependencies: 35 | * numpy 36 | * scipy 37 | * matplotlib 38 | * seaborn 39 | * pandas 40 | * argparse 41 | * path 42 | * scikit-learn 43 | * cython 44 | * GSL 45 | -------------------------------------------------------------------------------- /dat/icews/undirected/2003-D/data.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aschein/pgds/d2344d10eae1f807379d589ce4ca527b8a4660f5/dat/icews/undirected/2003-D/data.npz -------------------------------------------------------------------------------- /src/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python setup.py build_ext -i 3 | 4 | clean: 5 | rm -r build; rm *.c; rm *.cpp; rm *.so; rm *.html; rm *.pyc 6 | -------------------------------------------------------------------------------- /src/icews_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 33, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import numpy.random as rn\n", 13 | "%matplotlib notebook\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import seaborn as sns\n", 16 | "import pandas as pd\n", 17 | 
"from pgds import PGDS" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 34, 23 | "metadata": { 24 | "collapsed": true 25 | }, 26 | "outputs": [], 27 | "source": [ 28 | "data_dict = np.load('../dat/icews/undirected/2003-D/data.npz')" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 35, 34 | "metadata": { 35 | "collapsed": false 36 | }, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "T = 365 time steps\n", 43 | "V = 6197 features\n" 44 | ] 45 | } 46 | ], 47 | "source": [ 48 | "Y_TV = data_dict['Y_TV'] # observed TxV count matrix\n", 49 | "(T, V) = Y_TV.shape\n", 50 | "print 'T = %d time steps' % T\n", 51 | "print 'V = %d features' % V" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 36, 57 | "metadata": { 58 | "collapsed": false 59 | }, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "First time step: 2003-01-01T00:00:00.000000000\n", 66 | "Last time step: 2003-12-31T00:00:00.000000000\n" 67 | ] 68 | } 69 | ], 70 | "source": [ 71 | "dates_T = data_dict['dates_T'] # time steps are days in 2003\n", 72 | "print 'First time step: %s' % dates_T[0]\n", 73 | "print 'Last time step: %s' % dates_T[-1]" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 37, 79 | "metadata": { 80 | "collapsed": false 81 | }, 82 | "outputs": [ 83 | { 84 | "name": "stdout", 85 | "output_type": "stream", 86 | "text": [ 87 | "Most active feature: Iraq--United States\n", 88 | "Least active feature: Brazil--Uganda\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "labels_V = data_dict['labels_V'] # features are undirected edges of countries \n", 94 | "print 'Most active feature: %s' % labels_V[0]\n", 95 | "print 'Least active feature: %s' % labels_V[-1]" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "# Exploratory analysis" 103 | ] 104 | }, 105 | { 106 | 
"cell_type": "code", 107 | "execution_count": 38, 108 | "metadata": { 109 | "collapsed": false 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "K = 100 # number of latent components\n", 114 | "gam = 75 # shrinkage parameter\n", 115 | "tau = 1 # concentration parameter\n", 116 | "eps = 0.1 # uninformative gamma parameter\n", 117 | "stationary = True # stationary variant of the model\n", 118 | "steady = True # use steady state approx. (only for stationary)\n", 119 | "shrink = True # use the shrinkage version\n", 120 | "binary = False # whether the data is binary (vs. counts)\n", 121 | "seed = 111111 # random seed (optional)\n", 122 | "\n", 123 | "model = PGDS(T=T, V=V, K=K, eps=eps, gam=gam, tau=tau,\n", 124 | " stationary=int(stationary), steady=int(steady),\n", 125 | " shrink=int(shrink), binary=int(binary), seed=seed)" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 39, 131 | "metadata": { 132 | "collapsed": false 133 | }, 134 | "outputs": [], 135 | "source": [ 136 | "num_itns = 1000 # number of Gibbs sampling iterations (the more the merrier)\n", 137 | "verbose = False # whether to print out state\n", 138 | "initialize = True # whether to initialize model randomly\n", 139 | "\n", 140 | "model.fit(data=Y_TV,\n", 141 | " num_itns=num_itns,\n", 142 | " verbose=verbose,\n", 143 | " initialize=initialize)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": 40, 149 | "metadata": { 150 | "collapsed": false 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "state = dict(model.get_state())\n", 155 | "Theta_TK = state['Theta_TK'] # TxK time step factors\n", 156 | "Phi_KV = state['Phi_KV'] # KxV feature factors\n", 157 | "Pi_KK = state['Pi_KK'] # KxK transition matrix\n", 158 | "nu_K = state['nu_K'] # K component weights" 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 41, 164 | "metadata": { 165 | "collapsed": false 166 | }, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | 
"output_type": "stream", 171 | "text": [ 172 | "['Iraq--United States' 'Iraq--United Kingdom'\n", 173 | " 'United Kingdom--United States' 'Turkey--United States'\n", 174 | " 'Russian Federation--United States' 'Iraq--Russian Federation'\n", 175 | " 'Iraq--Turkey' 'France--United States' 'Australia--United States'\n", 176 | " 'South Korea--United States']\n" 177 | ] 178 | } 179 | ], 180 | "source": [ 181 | "top_k = nu_K.argmax() # most active component\n", 182 | "features = Phi_KV[top_k].argsort()[::-1][:10] # top 10 features in top k\n", 183 | "print labels_V[features]" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 42, 189 | "metadata": { 190 | "collapsed": false 191 | }, 192 | "outputs": [ 193 | { 194 | "data": { 195 | "application/javascript": [ 196 | "/* Put everything inside the global mpl namespace */\n", 197 | "window.mpl = {};\n", 198 | "\n", 199 | "mpl.get_websocket_type = function() {\n", 200 | " if (typeof(WebSocket) !== 'undefined') {\n", 201 | " return WebSocket;\n", 202 | " } else if (typeof(MozWebSocket) !== 'undefined') {\n", 203 | " return MozWebSocket;\n", 204 | " } else {\n", 205 | " alert('Your browser does not have WebSocket support.' +\n", 206 | " 'Please try Chrome, Safari or Firefox ≥ 6. ' +\n", 207 | " 'Firefox 4 and 5 are also supported but you ' +\n", 208 | " 'have to enable WebSockets in about:config.');\n", 209 | " };\n", 210 | "}\n", 211 | "\n", 212 | "mpl.figure = function(figure_id, websocket, ondownload, parent_element) {\n", 213 | " this.id = figure_id;\n", 214 | "\n", 215 | " this.ws = websocket;\n", 216 | "\n", 217 | " this.supports_binary = (this.ws.binaryType != undefined);\n", 218 | "\n", 219 | " if (!this.supports_binary) {\n", 220 | " var warnings = document.getElementById(\"mpl-warnings\");\n", 221 | " if (warnings) {\n", 222 | " warnings.style.display = 'block';\n", 223 | " warnings.textContent = (\n", 224 | " \"This browser does not support binary websocket messages. 
\" +\n", 225 | " \"Performance may be slow.\");\n", 226 | " }\n", 227 | " }\n", 228 | "\n", 229 | " this.imageObj = new Image();\n", 230 | "\n", 231 | " this.context = undefined;\n", 232 | " this.message = undefined;\n", 233 | " this.canvas = undefined;\n", 234 | " this.rubberband_canvas = undefined;\n", 235 | " this.rubberband_context = undefined;\n", 236 | " this.format_dropdown = undefined;\n", 237 | "\n", 238 | " this.image_mode = 'full';\n", 239 | "\n", 240 | " this.root = $('
');\n", 241 | " this._root_extra_style(this.root)\n", 242 | " this.root.attr('style', 'display: inline-block');\n", 243 | "\n", 244 | " $(parent_element).append(this.root);\n", 245 | "\n", 246 | " this._init_header(this);\n", 247 | " this._init_canvas(this);\n", 248 | " this._init_toolbar(this);\n", 249 | "\n", 250 | " var fig = this;\n", 251 | "\n", 252 | " this.waiting = false;\n", 253 | "\n", 254 | " this.ws.onopen = function () {\n", 255 | " fig.send_message(\"supports_binary\", {value: fig.supports_binary});\n", 256 | " fig.send_message(\"send_image_mode\", {});\n", 257 | " fig.send_message(\"refresh\", {});\n", 258 | " }\n", 259 | "\n", 260 | " this.imageObj.onload = function() {\n", 261 | " if (fig.image_mode == 'full') {\n", 262 | " // Full images could contain transparency (where diff images\n", 263 | " // almost always do), so we need to clear the canvas so that\n", 264 | " // there is no ghosting.\n", 265 | " fig.context.clearRect(0, 0, fig.canvas.width, fig.canvas.height);\n", 266 | " }\n", 267 | " fig.context.drawImage(fig.imageObj, 0, 0);\n", 268 | " };\n", 269 | "\n", 270 | " this.imageObj.onunload = function() {\n", 271 | " this.ws.close();\n", 272 | " }\n", 273 | "\n", 274 | " this.ws.onmessage = this._make_on_message_function(this);\n", 275 | "\n", 276 | " this.ondownload = ondownload;\n", 277 | "}\n", 278 | "\n", 279 | "mpl.figure.prototype._init_header = function() {\n", 280 | " var titlebar = $(\n", 281 | " '
');\n", 283 | " var titletext = $(\n", 284 | " '
');\n", 286 | " titlebar.append(titletext)\n", 287 | " this.root.append(titlebar);\n", 288 | " this.header = titletext[0];\n", 289 | "}\n", 290 | "\n", 291 | "\n", 292 | "\n", 293 | "mpl.figure.prototype._canvas_extra_style = function(canvas_div) {\n", 294 | "\n", 295 | "}\n", 296 | "\n", 297 | "\n", 298 | "mpl.figure.prototype._root_extra_style = function(canvas_div) {\n", 299 | "\n", 300 | "}\n", 301 | "\n", 302 | "mpl.figure.prototype._init_canvas = function() {\n", 303 | " var fig = this;\n", 304 | "\n", 305 | " var canvas_div = $('
');\n", 306 | "\n", 307 | " canvas_div.attr('style', 'position: relative; clear: both; outline: 0');\n", 308 | "\n", 309 | " function canvas_keyboard_event(event) {\n", 310 | " return fig.key_event(event, event['data']);\n", 311 | " }\n", 312 | "\n", 313 | " canvas_div.keydown('key_press', canvas_keyboard_event);\n", 314 | " canvas_div.keyup('key_release', canvas_keyboard_event);\n", 315 | " this.canvas_div = canvas_div\n", 316 | " this._canvas_extra_style(canvas_div)\n", 317 | " this.root.append(canvas_div);\n", 318 | "\n", 319 | " var canvas = $('');\n", 320 | " canvas.addClass('mpl-canvas');\n", 321 | " canvas.attr('style', \"left: 0; top: 0; z-index: 0; outline: 0\")\n", 322 | "\n", 323 | " this.canvas = canvas[0];\n", 324 | " this.context = canvas[0].getContext(\"2d\");\n", 325 | "\n", 326 | " var rubberband = $('');\n", 327 | " rubberband.attr('style', \"position: absolute; left: 0; top: 0; z-index: 1;\")\n", 328 | "\n", 329 | " var pass_mouse_events = true;\n", 330 | "\n", 331 | " canvas_div.resizable({\n", 332 | " start: function(event, ui) {\n", 333 | " pass_mouse_events = false;\n", 334 | " },\n", 335 | " resize: function(event, ui) {\n", 336 | " fig.request_resize(ui.size.width, ui.size.height);\n", 337 | " },\n", 338 | " stop: function(event, ui) {\n", 339 | " pass_mouse_events = true;\n", 340 | " fig.request_resize(ui.size.width, ui.size.height);\n", 341 | " },\n", 342 | " });\n", 343 | "\n", 344 | " function mouse_event_fn(event) {\n", 345 | " if (pass_mouse_events)\n", 346 | " return fig.mouse_event(event, event['data']);\n", 347 | " }\n", 348 | "\n", 349 | " rubberband.mousedown('button_press', mouse_event_fn);\n", 350 | " rubberband.mouseup('button_release', mouse_event_fn);\n", 351 | " // Throttle sequential mouse events to 1 every 20ms.\n", 352 | " rubberband.mousemove('motion_notify', mouse_event_fn);\n", 353 | "\n", 354 | " rubberband.mouseenter('figure_enter', mouse_event_fn);\n", 355 | " rubberband.mouseleave('figure_leave', 
mouse_event_fn);\n", 356 | "\n", 357 | " canvas_div.on(\"wheel\", function (event) {\n", 358 | " event = event.originalEvent;\n", 359 | " event['data'] = 'scroll'\n", 360 | " if (event.deltaY < 0) {\n", 361 | " event.step = 1;\n", 362 | " } else {\n", 363 | " event.step = -1;\n", 364 | " }\n", 365 | " mouse_event_fn(event);\n", 366 | " });\n", 367 | "\n", 368 | " canvas_div.append(canvas);\n", 369 | " canvas_div.append(rubberband);\n", 370 | "\n", 371 | " this.rubberband = rubberband;\n", 372 | " this.rubberband_canvas = rubberband[0];\n", 373 | " this.rubberband_context = rubberband[0].getContext(\"2d\");\n", 374 | " this.rubberband_context.strokeStyle = \"#000000\";\n", 375 | "\n", 376 | " this._resize_canvas = function(width, height) {\n", 377 | " // Keep the size of the canvas, canvas container, and rubber band\n", 378 | " // canvas in synch.\n", 379 | " canvas_div.css('width', width)\n", 380 | " canvas_div.css('height', height)\n", 381 | "\n", 382 | " canvas.attr('width', width);\n", 383 | " canvas.attr('height', height);\n", 384 | "\n", 385 | " rubberband.attr('width', width);\n", 386 | " rubberband.attr('height', height);\n", 387 | " }\n", 388 | "\n", 389 | " // Set the figure to an initial 600x600px, this will subsequently be updated\n", 390 | " // upon first draw.\n", 391 | " this._resize_canvas(600, 600);\n", 392 | "\n", 393 | " // Disable right mouse context menu.\n", 394 | " $(this.rubberband_canvas).bind(\"contextmenu\",function(e){\n", 395 | " return false;\n", 396 | " });\n", 397 | "\n", 398 | " function set_focus () {\n", 399 | " canvas.focus();\n", 400 | " canvas_div.focus();\n", 401 | " }\n", 402 | "\n", 403 | " window.setTimeout(set_focus, 100);\n", 404 | "}\n", 405 | "\n", 406 | "mpl.figure.prototype._init_toolbar = function() {\n", 407 | " var fig = this;\n", 408 | "\n", 409 | " var nav_element = $('
')\n", 410 | " nav_element.attr('style', 'width: 100%');\n", 411 | " this.root.append(nav_element);\n", 412 | "\n", 413 | " // Define a callback function for later on.\n", 414 | " function toolbar_event(event) {\n", 415 | " return fig.toolbar_button_onclick(event['data']);\n", 416 | " }\n", 417 | " function toolbar_mouse_event(event) {\n", 418 | " return fig.toolbar_button_onmouseover(event['data']);\n", 419 | " }\n", 420 | "\n", 421 | " for(var toolbar_ind in mpl.toolbar_items) {\n", 422 | " var name = mpl.toolbar_items[toolbar_ind][0];\n", 423 | " var tooltip = mpl.toolbar_items[toolbar_ind][1];\n", 424 | " var image = mpl.toolbar_items[toolbar_ind][2];\n", 425 | " var method_name = mpl.toolbar_items[toolbar_ind][3];\n", 426 | "\n", 427 | " if (!name) {\n", 428 | " // put a spacer in here.\n", 429 | " continue;\n", 430 | " }\n", 431 | " var button = $('