├── data ├── ebola │ ├── rstb20160308supp1.xlsx │ ├── rstb20160308_si_002.xlsx │ ├── shapefiles │ │ └── gin-lbr-sle │ │ │ ├── gin-lbr-sle.dbf │ │ │ ├── gin-lbr-sle.shp │ │ │ ├── gin-lbr-sle.shx │ │ │ ├── gin-lbr-sle.prj │ │ │ ├── gin-lbr-sle.qpj │ │ │ └── gin-lbr-sle.gal │ ├── ebola_populations.csv │ ├── ebola_pos_list.csv │ ├── ebola_net_edge_list.csv │ ├── district_adjacency.txt │ └── ebola_base_graph.json └── contiguous-usa.txt ├── lib ├── settings.py ├── metrics.py ├── utils.py ├── graph_generation.py ├── ogata │ ├── main_evaluate.py │ ├── main_simulate.py │ ├── stochastic_processes.py │ ├── helpers.py │ ├── visualize_OPT.py │ ├── experiment.py │ └── analysis.py ├── poisson_renewal.py ├── maxcut.py └── dynamics.py ├── README.md ├── setup.py ├── notebooks ├── Run - Validate Epidemic parameters for Ebola dataset.ipynb ├── 2-calibration │ ├── 2-baseline-calibration.py │ ├── script_baseline_calibration_soc.py │ └── 2-baseline-calibration.ipynb ├── 0-tutorial │ ├── 2-network-generation.ipynb │ └── 1-SIR-simulation.ipynb ├── Profile - Experiment.ipynb ├── Evaluation - multi (old code version).ipynb └── 1-preprocessing │ └── Debug - dynamics_ind.SimulationSIR.ipynb ├── requirements.txt ├── .gitignore ├── script_many_jobs.py └── script_single_job.py /data/ebola/rstb20160308supp1.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/disease-control/HEAD/data/ebola/rstb20160308supp1.xlsx -------------------------------------------------------------------------------- /data/ebola/rstb20160308_si_002.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/disease-control/HEAD/data/ebola/rstb20160308_si_002.xlsx -------------------------------------------------------------------------------- /data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.dbf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/disease-control/HEAD/data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.dbf -------------------------------------------------------------------------------- /data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.shp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/disease-control/HEAD/data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.shp -------------------------------------------------------------------------------- /data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.shx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Networks-Learning/disease-control/HEAD/data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.shx -------------------------------------------------------------------------------- /data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.prj: -------------------------------------------------------------------------------- 1 | GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]] -------------------------------------------------------------------------------- /data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.qpj: -------------------------------------------------------------------------------- 1 | GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]] 2 | -------------------------------------------------------------------------------- /lib/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | PROJECT_DIR = 
os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..')) 4 | DATA_DIR = os.path.join(PROJECT_DIR, 'data') 5 | 6 | EBOLA_BASE_GRAPH_FILE = os.path.join(DATA_DIR, 'ebola', 'ebola_base_graph.json') 7 | EBOLA_SCALED_GRAPH_FILE = os.path.join(DATA_DIR, 'ebola', 'ebola_scaled_graph.json') 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # disease-control 2 | 3 | This is the code for the simulation of the project on disease control with SDEs with jumps. 4 | The following describes the different files and how they interact. 5 | 6 | ## Installation 7 | 8 | To run the experiment notebooks, first install the internal library using: 9 | 10 | ``` 11 | python setup.py install -e . 12 | ``` 13 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='disease-control', 5 | version='0.1', 6 | url='https://github.com/Networks-Learning/disease-control.git', 7 | description='Optimal Stochastic Control algorithm for epidemic models', 8 | packages=find_packages(), 9 | install_requires=['numpy >= 1.16.2', 'networkx >= 2.0', 'scipy >= 1.2.1', 10 | 'lpsolvers >= 0.8.9', 'pandas >= 0.24.1'], 11 | ) 12 | -------------------------------------------------------------------------------- /notebooks/Run - Validate Epidemic parameters for Ebola dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [] 9 | } 10 | ], 11 | "metadata": { 12 | "kernelspec": { 13 | "display_name": "Python 3", 14 | "language": "python", 15 | "name": "python3" 16 | }, 17 | "language_info": { 18 | 
"codemirror_mode": { 19 | "name": "ipython", 20 | "version": 3 21 | }, 22 | "file_extension": ".py", 23 | "mimetype": "text/x-python", 24 | "name": "python", 25 | "nbconvert_exporter": "python", 26 | "pygments_lexer": "ipython3", 27 | "version": "3.6.3" 28 | } 29 | }, 30 | "nbformat": 4, 31 | "nbformat_minor": 2 32 | } 33 | -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from lib.poisson_renewal import ModelPoissonRenewal, Learner 5 | 6 | 7 | def wasserstein_distance(t1, t2): 8 | n = min(len(t1), len(t2)) 9 | m = max(len(t1), len(t2)) 10 | if len(t1) == 0: 11 | T = max(t2) 12 | elif len(t2) == 0: 13 | T = max(t1) 14 | else: 15 | T = max(max(t1), max(t2)) 16 | val = sum(abs(t1[:n] - t2[:n])) + (m - n) * T - sum(t1[n:]) - sum(t2[n:]) 17 | return val 18 | 19 | 20 | def estimate_reproduction_number(daily_count_arr, beta, T): 21 | model = ModelPoissonRenewal() 22 | learner = Learner(model, lr=0.01, lr_gamma=1.0, tol=1e-6, max_iter=10000) 23 | log_r_init = torch.tensor([0.0], dtype=torch.float64) 24 | log_r_hat = learner.fit(log_r_init, daily_count_arr, beta, T) 25 | log_r_hat = log_r_hat.detach().numpy()[0] 26 | r_hat = math.exp(log_r_hat) 27 | return r_hat 28 | -------------------------------------------------------------------------------- /data/ebola/ebola_populations.csv: -------------------------------------------------------------------------------- 1 | district,population 2 | BEYLA,248143 3 | BOFFA,217743 4 | BOKE,526569 5 | CONAKRY,1729239 6 | COYAH,407975 7 | DABOLA,181413 8 | DALABA,212443 9 | DINGUIRAYE,202818 10 | DUBREKA,200255 11 | FARANAH,218780 12 | FORECARIAH,418767 13 | FRIA,142252 14 | GUECKEDOU,801976 15 | KANKAN,446082 16 | KEROUANE,291859 17 | KINDIA,552940 18 | KISSIDOUGO,329839 19 | KOUROUSSA,219639 20 | LOLA,247806 21 | MACENTA,553487 22 | MALI,237519 23 | NZEREKORE,445936 24 
| PITA,301014 25 | SIGUIRI,483111 26 | TELIMELE,294313 27 | TOUGUE,181277 28 | YOMOU,327652 29 | BOMI,145557 30 | BONG,423682 31 | GBARPOLU,111770 32 | GRAND_BASSA,260551 33 | GRAND_CAPE_MOUNT,164466 34 | GRAND_GEDEH,164629 35 | GRAND_KRU,75021 36 | LOFA,371662 37 | MARGIBI,344143 38 | MARYLAND,145228 39 | MONTSERRADO,1174503 40 | NIMBA,631097 41 | RIVER_GEE,102026 42 | RIVERCESS,88858 43 | SINOE,124603 44 | BO,582706 45 | BOMBALI,504504 46 | BONTHE,157694 47 | KAILAHUN,404088 48 | KAMBIA,305849 49 | KENEMA,657056 50 | KOINADUGU,325595 51 | KONO,490092 52 | MOYAMBA,324082 53 | PORT_LOKO,559287 54 | PUJEHUN,279316 55 | TONKOLILI,439166 56 | WESTERN,1128605 -------------------------------------------------------------------------------- /data/contiguous-usa.txt: -------------------------------------------------------------------------------- 1 | 1 2 2 | 1 3 3 | 1 4 4 | 1 5 5 | 6 7 6 | 6 8 7 | 6 4 8 | 6 9 9 | 6 5 10 | 6 10 11 | 11 12 12 | 11 13 13 | 11 14 14 | 11 15 15 | 12 14 16 | 12 16 17 | 17 18 18 | 17 19 19 | 17 13 20 | 17 9 21 | 17 15 22 | 17 20 23 | 21 22 24 | 21 23 25 | 21 24 26 | 25 26 27 | 25 27 28 | 28 26 29 | 28 29 30 | 28 30 31 | 2 3 32 | 3 31 33 | 3 32 34 | 3 5 35 | 33 34 36 | 33 35 37 | 33 8 38 | 33 19 39 | 33 36 40 | 33 37 41 | 38 39 42 | 38 14 43 | 38 16 44 | 38 15 45 | 38 40 46 | 38 20 47 | 34 41 48 | 34 42 49 | 34 8 50 | 34 37 51 | 41 42 52 | 41 43 53 | 41 44 54 | 18 8 55 | 18 19 56 | 18 9 57 | 42 8 58 | 42 44 59 | 42 5 60 | 42 27 61 | 42 45 62 | 7 4 63 | 7 10 64 | 22 46 65 | 22 23 66 | 22 24 67 | 22 47 68 | 26 30 69 | 26 27 70 | 26 45 71 | 48 46 72 | 43 44 73 | 43 37 74 | 35 49 75 | 35 36 76 | 35 37 77 | 8 19 78 | 8 9 79 | 8 5 80 | 4 5 81 | 39 49 82 | 39 36 83 | 39 20 84 | 31 32 85 | 31 5 86 | 31 27 87 | 49 36 88 | 19 36 89 | 19 20 90 | 46 47 91 | 29 23 92 | 29 30 93 | 13 9 94 | 13 10 95 | 14 16 96 | 14 15 97 | 23 30 98 | 23 47 99 | 44 30 100 | 44 45 101 | 9 10 102 | 16 40 103 | 30 45 104 | 36 20 105 | 5 27 106 | 15 20 107 | 27 45 108 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appnope==0.1.0 2 | attrs==18.2.0 3 | backcall==0.1.0 4 | bleach==3.3.0 5 | branca==0.3.1 6 | certifi==2019.3.9 7 | chardet==3.0.4 8 | cycler==0.10.0 9 | decorator==4.3.2 10 | defusedxml==0.5.0 11 | entrypoints==0.3 12 | folium==0.8.3 13 | idna==2.8 14 | ipykernel==5.1.0 15 | ipython==7.16.3 16 | ipython-genutils==0.2.0 17 | ipywidgets==7.4.2 18 | jedi==0.13.3 19 | Jinja2==2.11.3 20 | joblib==0.13.2 21 | jsonschema==3.0.0 22 | jupyter==1.0.0 23 | jupyter-client==5.2.4 24 | jupyter-console==6.0.0 25 | jupyter-core==4.4.0 26 | kiwisolver==1.0.1 27 | MarkupSafe==1.1.1 28 | matplotlib==3.0.2 29 | mistune==0.8.4 30 | nbconvert==5.4.1 31 | nbformat==4.4.0 32 | networkx==2.2 33 | notebook==6.4.12 34 | numpy==1.22.0 35 | pandas==0.24.1 36 | pandocfilters==1.4.2 37 | parso==0.3.4 38 | pexpect==4.6.0 39 | pickleshare==0.7.5 40 | prometheus-client==0.6.0 41 | prompt-toolkit==2.0.9 42 | ptyprocess==0.6.0 43 | Pygments==2.7.4 44 | pyparsing==2.3.1 45 | pyrsistent==0.14.11 46 | python-dateutil==2.8.0 47 | pytz==2018.9 48 | pyzmq==18.0.0 49 | qtconsole==4.4.3 50 | requests==2.21.0 51 | scipy==1.2.1 52 | Send2Trash==1.5.0 53 | six==1.12.0 54 | terminado==0.8.1 55 | testpath==0.4.2 56 | tornado==5.1.1 57 | tqdm==4.31.1 58 | traitlets==4.3.2 59 | urllib3==1.26.5 60 | wcwidth==0.1.7 61 | webencodings==0.5.1 62 | widgetsnbextension==3.4.2 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore custom folders 2 | # data/ 3 | output/ 4 | 5 | temp_pickles/ 6 | plots/ 7 | graphs/ 8 | *.DS_Store 9 | .DS_Store 10 | 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / 
packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | #lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, Counter 2 | import numpy as np 3 | 4 | 5 | def compute_infection_time_per_district(inf_arr, 
district_arr): 6 | """ 7 | Return a dict keyed by district contained in node metadata of the graph `sir.G`. Each value 8 | is a list of all sorted infection times in the district. 9 | """ 10 | district_inf_times = defaultdict(list) 11 | # Iterate over infection times of each node in the simulation 12 | for i, (inf_time, district) in enumerate(zip(inf_arr, district_arr)): 13 | if (inf_time > 0) and (inf_time < np.inf): 14 | district_inf_times[district].append(inf_time) 15 | # Format as dict and sort values 16 | district_inf_times = dict(district_inf_times) 17 | for k, v in district_inf_times.items(): 18 | district_inf_times[k] = np.array(sorted(v)) 19 | return district_inf_times 20 | 21 | 22 | def compute_r0_per_country(inf_time_arr, infector_arr, country_arr): 23 | """ 24 | Compute the basic reproduction number R0 per country for the given data. 25 | 26 | inf_time_arr : array-like (shape: (N,)) 27 | Infection time of each of the N nodes 28 | infector_arr : array-like (shape: (N,)) 29 | Index of the 30 | """ 31 | # Count of secondary infections: {node_idx: num of infected neighbors} 32 | infector_count = Counter(infector_arr) 33 | # Indices of infected nodes 34 | infected_node_indices = np.where(np.array(inf_time_arr) < np.inf)[0] 35 | # Initialize the list of number of secondary infections per country 36 | country_count = {country: list() for country in set(country_arr)} 37 | # For each infected node, add its number of secondary case to its country 38 | for u_idx in infected_node_indices: 39 | u_country = country_arr[u_idx] 40 | inf_count = infector_count[u_idx] 41 | country_count[u_country].append(inf_count) 42 | country_count = dict(country_count) 43 | # Compute R0 as the mean number of secondary case for each country 44 | countru_r0_dict = {country: np.mean(count) if len(count) > 0 else 0.0 for country, count in country_count.items()} 45 | return countru_r0_dict 46 | -------------------------------------------------------------------------------- 
/data/ebola/ebola_pos_list.csv: -------------------------------------------------------------------------------- 1 | country,district,latitude,longitude 2 | Guinea,BEYLA,8.683333,-8.633333 3 | Sierra Leone,BO,7.9552,-11.471 4 | Guinea,BOFFA,10.180825,-14.039161 5 | Guinea,BOKE,11.186467,-14.100133 6 | Sierra Leone,BOMBALI,9.247584,-12.163272 7 | Liberia,BOMI,6.756293,-10.845147 8 | Liberia,BONG,6.829502,-9.367308 9 | Sierra Leone,BONTHE,7.525703,-12.503992 10 | Guinea,CONAKRY,9.641185,-13.578401 11 | Guinea,COYAH,9.708636,-13.387612 12 | Guinea,DABOLA,10.729781,-11.110785 13 | Guinea,DALABA,10.686818,-12.24907 14 | Guinea,DINGUIRAYE,11.289951,-10.715423 15 | Guinea,DUBREKA,9.790735,-13.514774 16 | Guinea,FARANAH,10.045102,-10.749247 17 | Guinea,FORECARIAH,9.434471,-13.090435 18 | Guinea,FRIA,10.367454,-13.584187 19 | Liberia,GBARPOLU,7.495264,-10.08073 20 | Liberia,GRAND BASSA,6.230845,-9.812493 21 | Liberia,GRAND CAPE MOUNT,7.046776,-11.071176 22 | Liberia,GRAND GEDEH,5.922208,-8.221298 23 | Liberia,GRAND KRU,4.761386,-8.221298 24 | Guinea,GUECKEDOU,8.564969,-10.131116 25 | Sierra Leone,KAILAHUN,8.28022,-10.571809 26 | Sierra Leone,KAMBIA,9.126166,-12.917652 27 | Guinea,KANKAN,10.382789,-9.311828 28 | Sierra Leone,KENEMA,7.863215,-11.195717 29 | Guinea,KEROUANE,9.27026,-9.007367 30 | Guinea,KINDIA,10.040672,-12.862989 31 | Guinea,KISSIDOUGO,9.191454,-10.114318 32 | Sierra Leone,KOINADUGU,9.530862,-11.524805 33 | Sierra Leone,KONO,8.766329,-10.89031 34 | Guinea,KOUROUSSA,10.648923,-9.885059 35 | Liberia,LOFA,8.191118,-9.723267 36 | Guinea,LOLA,7.802235,-8.533653 37 | Guinea,MACENTA,8.538294,-9.472824 38 | Guinea,MALI,12.074294,-12.297718 39 | Liberia,MARGIBI,6.515187,-10.30489 40 | Liberia,MARYLAND,4.725888,-7.74167 41 | Liberia,MONTSERRADO,6.552581,-10.529611 42 | Sierra Leone,MOYAMBA,8.162051,-12.435192 43 | Guinea,N'ZEREKORE,7.747836,-8.82525 44 | Liberia,NIMBA,6.842761,-8.660059 45 | Guinea,PITA,11.057462,-12.397943 46 | Sierra Leone,PORT 
LOKO,8.768689,-12.785352 47 | Sierra Leone,PUJEHUN,7.356299,-11.721064 48 | Liberia,RIVER GEE,5.260489,-7.87216 49 | Liberia,RIVERCESS,5.902533,-9.456155 50 | Guinea,SIGUIRI,11.414811,-9.17883 51 | Liberia,SINOE,5.49871,-8.660059 52 | Guinea,TELIMELE,10.908936,-13.029933 53 | Sierra Leone,TONKOLILI,8.738942,-11.797961 54 | Guinea,TOUGUE,11.446422,-11.664139 55 | Sierra Leone,WESTERN,8.311498,-13.035694 56 | Guinea,YOMOU,7.569628,-9.259157 57 | -------------------------------------------------------------------------------- /data/ebola/ebola_net_edge_list.csv: -------------------------------------------------------------------------------- 1 | BEYLA KANKAN 2 | BEYLA KEROUANE 3 | BEYLA LOLA 4 | BEYLA MACENTA 5 | BEYLA NZEREKORE 6 | KANKAN KEROUANE 7 | KANKAN KISSIDOUGO 8 | KANKAN KOUROUSSA 9 | KANKAN SIGUIRI 10 | KEROUANE KISSIDOUGO 11 | KEROUANE MACENTA 12 | LOLA NZEREKORE 13 | LOLA NIMBA 14 | MACENTA GUECKEDOU 15 | MACENTA KISSIDOUGO 16 | MACENTA NZEREKORE 17 | MACENTA YOMOU 18 | MACENTA LOFA 19 | NZEREKORE YOMOU 20 | NZEREKORE NIMBA 21 | KISSIDOUGO FARANAH 22 | KISSIDOUGO KOUROUSSA 23 | KISSIDOUGO GUECKEDOU 24 | KOUROUSSA DABOLA 25 | KOUROUSSA DINGUIRAYE 26 | KOUROUSSA FARANAH 27 | KOUROUSSA SIGUIRI 28 | SIGUIRI DINGUIRAYE 29 | NIMBA YOMOU 30 | NIMBA BONG 31 | NIMBA GRAND_BASSA 32 | NIMBA GRAND_GEDEH 33 | NIMBA RIVERCESS 34 | NIMBA SINOE 35 | GUECKEDOU FARANAH 36 | GUECKEDOU LOFA 37 | GUECKEDOU KAILAHUN 38 | GUECKEDOU KOINADUGU 39 | GUECKEDOU KONO 40 | YOMOU LOFA 41 | YOMOU BONG 42 | LOFA BONG 43 | LOFA GBARPOLU 44 | LOFA KAILAHUN 45 | BOFFA BOKE 46 | BOFFA DUBREKA 47 | BOFFA FRIA 48 | BOFFA TELIMELE 49 | BOKE TELIMELE 50 | DUBREKA CONAKRY 51 | DUBREKA COYAH 52 | DUBREKA FRIA 53 | DUBREKA KINDIA 54 | DUBREKA TELIMELE 55 | FRIA TELIMELE 56 | TELIMELE KINDIA 57 | TELIMELE PITA 58 | CONAKRY COYAH 59 | COYAH FORECARIAH 60 | COYAH KINDIA 61 | KINDIA FORECARIAH 62 | KINDIA DALABA 63 | KINDIA PITA 64 | KINDIA BOMBALI 65 | PITA DALABA 66 | PITA MALI 67 | FORECARIAH 
BOMBALI 68 | FORECARIAH KAMBIA 69 | BOMBALI KOINADUGU 70 | BOMBALI KAMBIA 71 | BOMBALI PORT_LOKO 72 | BOMBALI TONKOLILI 73 | KAMBIA PORT_LOKO 74 | DALABA DABOLA 75 | DALABA FARANAH 76 | DALABA TOUGUE 77 | DABOLA DINGUIRAYE 78 | DABOLA FARANAH 79 | DINGUIRAYE TOUGUE 80 | FARANAH KOINADUGU 81 | TOUGUE MALI 82 | KOINADUGU KONO 83 | KOINADUGU TONKOLILI 84 | KAILAHUN GBARPOLU 85 | KAILAHUN KENEMA 86 | KAILAHUN KONO 87 | KONO KENEMA 88 | KONO TONKOLILI 89 | TONKOLILI KENEMA 90 | TONKOLILI BO 91 | TONKOLILI MOYAMBA 92 | TONKOLILI PORT_LOKO 93 | PORT_LOKO MOYAMBA 94 | PORT_LOKO WESTERN 95 | BONG BOMI 96 | BONG GBARPOLU 97 | BONG GRAND_BASSA 98 | BONG MARGIBI 99 | BONG MONTSERRADO 100 | GBARPOLU BOMI 101 | GBARPOLU GRAND_CAPE_MOUNT 102 | GBARPOLU MONTSERRADO 103 | GBARPOLU KENEMA 104 | KENEMA GRAND_CAPE_MOUNT 105 | KENEMA BO 106 | KENEMA PUJEHUN 107 | GRAND_BASSA MARGIBI 108 | GRAND_BASSA RIVERCESS 109 | GRAND_GEDEH RIVERCESS 110 | GRAND_GEDEH RIVER_GEE 111 | GRAND_GEDEH SINOE 112 | RIVERCESS SINOE 113 | SINOE RIVER_GEE 114 | SINOE GRAND_KRU 115 | BOMI GRAND_CAPE_MOUNT 116 | BOMI MONTSERRADO 117 | MARGIBI MONTSERRADO 118 | GRAND_CAPE_MOUNT PUJEHUN 119 | PUJEHUN BO 120 | PUJEHUN BONTHE 121 | BO BONTHE 122 | BO MOYAMBA 123 | BONTHE MOYAMBA 124 | RIVER_GEE GRAND_KRU 125 | RIVER_GEE MARYLAND 126 | GRAND_KRU MARYLAND 127 | MOYAMBA WESTERN 128 | -------------------------------------------------------------------------------- /data/ebola/district_adjacency.txt: -------------------------------------------------------------------------------- 1 | 0 55 gin-lbr-sle adj_mtrx_n 2 | BEYLA 5 3 | KANKAN KEROUANE LOLA MACENTA NZEREKORE 4 | BOFFA 4 5 | BOKE DUBREKA FRIA TELIMELE 6 | BOKE 2 7 | BOFFA TELIMELE 8 | CONAKRY 2 9 | DUBREKA COYAH 10 | COYAH 4 11 | CONAKRY DUBREKA FORECARIAH KINDIA 12 | DABOLA 3 13 | DINGUIRAYE FARANAH KOUROUSSA 14 | DALABA 3 15 | KINDIA PITA TOUGUE 16 | DINGUIRAYE 4 17 | DABOLA KOUROUSSA SIGUIRI TOUGUE 18 | DUBREKA 6 19 | BOFFA CONAKRY COYAH FRIA KINDIA TELIMELE 20 
| FARANAH 5 21 | DABOLA GUECKEDOU KISSIDOUGO KOUROUSSA KOINADUGU 22 | FORECARIAH 4 23 | COYAH KINDIA BOMBALI KAMBIA 24 | FRIA 3 25 | BOFFA TELIMELE DUBREKA 26 | GUECKEDOU 7 27 | FARANAH KISSIDOUGO MACENTA LOFA KAILAHUN KOINADUGU KONO 28 | KANKAN 5 29 | BEYLA KEROUANE KISSIDOUGO KOUROUSSA SIGUIRI 30 | KEROUANE 4 31 | BEYLA KANKAN KISSIDOUGO MACENTA 32 | KINDIA 7 33 | COYAH DALABA DUBREKA FORECARIAH PITA TELIMELE BOMBALI 34 | KISSIDOUGO 6 35 | FARANAH GUECKEDOU KANKAN KEROUANE KOUROUSSA MACENTA 36 | KOUROUSSA 6 37 | DABOLA DINGUIRAYE FARANAH KANKAN KISSIDOUGO SIGUIRI 38 | LOLA 3 39 | BEYLA NZEREKORE NIMBA 40 | MACENTA 7 41 | BEYLA GUECKEDOU KEROUANE KISSIDOUGO NZEREKORE YOMOU LOFA 42 | MALI 0 43 | 44 | NZEREKORE 5 45 | BEYLA LOLA MACENTA YOMOU NIMBA 46 | PITA 3 47 | DALABA KINDIA TELIMELE 48 | SIGUIRI 3 49 | DINGUIRAYE KANKAN KOUROUSSA 50 | TELIMELE 6 51 | BOFFA BOKE DUBREKA FRIA KINDIA PITA 52 | TOUGUE 2 53 | DALABA DINGUIRAYE 54 | YOMOU 5 55 | MACENTA NZEREKORE BONG LOFA NIMBA 56 | BOMI 4 57 | BONG GBARPOLU GRAND_CAPE_MOUNT MONTSERRADO 58 | BONG 8 59 | YOMOU BOMI GBARPOLU GRAND_BASSA LOFA MARGIBI MONTSERRADO NIMBA 60 | GBARPOLU 7 61 | BOMI BONG GRAND_CAPE_MOUNT LOFA MONTSERRADO KAILAHUN KENEMA 62 | GRAND_BASSA 4 63 | BONG MARGIBI NIMBA RIVERCESS 64 | GRAND_CAPE_MOUNT 4 65 | BOMI GBARPOLU KENEMA PUJEHUN 66 | GRAND_GEDEH 4 67 | NIMBA RIVER_GEE RIVERCESS SINOE 68 | GRAND_KRU 3 69 | MARYLAND RIVER_GEE SINOE 70 | LOFA 6 71 | GUECKEDOU MACENTA YOMOU BONG GBARPOLU KAILAHUN 72 | MARGIBI 3 73 | BONG GRAND_BASSA MONTSERRADO 74 | MARYLAND 2 75 | GRAND_KRU RIVER_GEE 76 | MONTSERRADO 4 77 | MARGIBI GBARPOLU BONG BOMI 78 | NIMBA 8 79 | LOLA NZEREKORE YOMOU BONG GRAND_BASSA GRAND_GEDEH RIVERCESS SINOE 80 | RIVER_GEE 4 81 | GRAND_GEDEH GRAND_KRU MARYLAND SINOE 82 | RIVERCESS 4 83 | GRAND_BASSA SINOE NIMBA GRAND_GEDEH 84 | SINOE 5 85 | GRAND_GEDEH GRAND_KRU NIMBA RIVER_GEE RIVERCESS 86 | BO 5 87 | BONTHE KENEMA MOYAMBA PUJEHUN TONKOLILI 88 | BOMBALI 6 89 | FORECARIAH KINDIA KAMBIA 
KOINADUGU PORT_LOKO TONKOLILI 90 | BONTHE 3 91 | BO MOYAMBA PUJEHUN 92 | KAILAHUN 5 93 | GUECKEDOU GBARPOLU LOFA KENEMA KONO 94 | KAMBIA 3 95 | FORECARIAH BOMBALI PORT_LOKO 96 | KENEMA 7 97 | GBARPOLU GRAND_CAPE_MOUNT BO KAILAHUN KONO PUJEHUN TONKOLILI 98 | KOINADUGU 5 99 | FARANAH GUECKEDOU BOMBALI KONO TONKOLILI 100 | KONO 5 101 | GUECKEDOU KAILAHUN KENEMA KOINADUGU TONKOLILI 102 | MOYAMBA 5 103 | BO BONTHE PORT_LOKO TONKOLILI WESTERN 104 | PORT_LOKO 5 105 | BOMBALI KAMBIA MOYAMBA TONKOLILI WESTERN 106 | PUJEHUN 4 107 | GRAND_CAPE_MOUNT BO BONTHE KENEMA 108 | TONKOLILI 7 109 | BO BOMBALI KENEMA KOINADUGU KONO MOYAMBA PORT_LOKO 110 | WESTERN 2 111 | PORT_LOKO MOYAMBA -------------------------------------------------------------------------------- /data/ebola/shapefiles/gin-lbr-sle/gin-lbr-sle.gal: -------------------------------------------------------------------------------- 1 | 0 55 gin-lbr-sle adj_mtrx_n 2 | BEYLA 5 3 | KANKAN KEROUANE LOLA MACENTA NZEREKORE 4 | BOFFA 4 5 | BOKE DUBREKA FRIA TELIMELE 6 | BOKE 2 7 | BOFFA TELIMELE 8 | CONAKRY 2 9 | DUBREKA COYAH 10 | COYAH 4 11 | CONAKRY DUBREKA FORECARIAH KINDIA 12 | DABOLA 3 13 | DINGUIRAYE FARANAH KOUROUSSA 14 | DALABA 3 15 | KINDIA PITA TOUGUE 16 | DINGUIRAYE 4 17 | DABOLA KOUROUSSA SIGUIRI TOUGUE 18 | DUBREKA 6 19 | BOFFA CONAKRY COYAH FRIA KINDIA TELIMELE 20 | FARANAH 5 21 | DABOLA GUECKEDOU KISSIDOUGO KOUROUSSA KOINADUGU 22 | FORECARIAH 4 23 | COYAH KINDIA BOMBALI KAMBIA 24 | FRIA 3 25 | BOFFA TELIMELE DUBREKA 26 | GUECKEDOU 7 27 | FARANAH KISSIDOUGO MACENTA LOFA KAILAHUN KOINADUGU KONO 28 | KANKAN 5 29 | BEYLA KEROUANE KISSIDOUGO KOUROUSSA SIGUIRI 30 | KEROUANE 4 31 | BEYLA KANKAN KISSIDOUGO MACENTA 32 | KINDIA 7 33 | COYAH DALABA DUBREKA FORECARIAH PITA TELIMELE BOMBALI 34 | KISSIDOUGO 6 35 | FARANAH GUECKEDOU KANKAN KEROUANE KOUROUSSA MACENTA 36 | KOUROUSSA 6 37 | DABOLA DINGUIRAYE FARANAH KANKAN KISSIDOUGO SIGUIRI 38 | LOLA 3 39 | BEYLA NZEREKORE NIMBA 40 | MACENTA 7 41 | BEYLA GUECKEDOU KEROUANE 
KISSIDOUGO NZEREKORE YOMOU LOFA 42 | MALI 0 43 | 44 | NZEREKORE 5 45 | BEYLA LOLA MACENTA YOMOU NIMBA 46 | PITA 3 47 | DALABA KINDIA TELIMELE 48 | SIGUIRI 3 49 | DINGUIRAYE KANKAN KOUROUSSA 50 | TELIMELE 6 51 | BOFFA BOKE DUBREKA FRIA KINDIA PITA 52 | TOUGUE 2 53 | DALABA DINGUIRAYE 54 | YOMOU 5 55 | MACENTA NZEREKORE BONG LOFA NIMBA 56 | BOMI 4 57 | BONG GBARPOLU GRAND_CAPE_MOUNT MONTSERRADO 58 | BONG 8 59 | YOMOU BOMI GBARPOLU GRAND_BASSA LOFA MARGIBI MONTSERRADO NIMBA 60 | GBARPOLU 7 61 | BOMI BONG GRAND_CAPE_MOUNT LOFA MONTSERRADO KAILAHUN KENEMA 62 | GRAND_BASSA 4 63 | BONG MARGIBI NIMBA RIVERCESS 64 | GRAND_CAPE_MOUNT 4 65 | BOMI GBARPOLU KENEMA PUJEHUN 66 | GRAND_GEDEH 4 67 | NIMBA RIVER_GEE RIVERCESS SINOE 68 | GRAND_KRU 3 69 | MARYLAND RIVER_GEE SINOE 70 | LOFA 6 71 | GUECKEDOU MACENTA YOMOU BONG GBARPOLU KAILAHUN 72 | MARGIBI 3 73 | BONG GRAND_BASSA MONTSERRADO 74 | MARYLAND 2 75 | GRAND_KRU RIVER_GEE 76 | MONTSERRADO 4 77 | MARGIBI GBARPOLU BONG BOMI 78 | NIMBA 8 79 | LOLA NZEREKORE YOMOU BONG GRAND_BASSA GRAND_GEDEH RIVERCESS SINOE 80 | RIVER_GEE 4 81 | GRAND_GEDEH GRAND_KRU MARYLAND SINOE 82 | RIVERCESS 4 83 | GRAND_BASSA SINOE NIMBA GRAND_GEDEH 84 | SINOE 5 85 | GRAND_GEDEH GRAND_KRU NIMBA RIVER_GEE RIVERCESS 86 | BO 5 87 | BONTHE KENEMA MOYAMBA PUJEHUN TONKOLILI 88 | BOMBALI 6 89 | FORECARIAH KINDIA KAMBIA KOINADUGU PORT_LOKO TONKOLILI 90 | BONTHE 3 91 | BO MOYAMBA PUJEHUN 92 | KAILAHUN 5 93 | GUECKEDOU GBARPOLU LOFA KENEMA KONO 94 | KAMBIA 3 95 | FORECARIAH BOMBALI PORT_LOKO 96 | KENEMA 7 97 | GBARPOLU GRAND_CAPE_MOUNT BO KAILAHUN KONO PUJEHUN TONKOLILI 98 | KOINADUGU 5 99 | FARANAH GUECKEDOU BOMBALI KONO TONKOLILI 100 | KONO 5 101 | GUECKEDOU KAILAHUN KENEMA KOINADUGU TONKOLILI 102 | MOYAMBA 5 103 | BO BONTHE PORT_LOKO TONKOLILI WESTERN 104 | PORT_LOKO 5 105 | BOMBALI KAMBIA MOYAMBA TONKOLILI WESTERN 106 | PUJEHUN 4 107 | GRAND_CAPE_MOUNT BO BONTHE KENEMA 108 | TONKOLILI 7 109 | BO BOMBALI KENEMA KOINADUGU KONO MOYAMBA PORT_LOKO 110 | WESTERN 2 111 
| PORT_LOKO MOYAMBA 112 | -------------------------------------------------------------------------------- /lib/graph_generation.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import networkx as nx 7 | 8 | from .settings import EBOLA_BASE_GRAPH_FILE, EBOLA_SCALED_GRAPH_FILE 9 | 10 | 11 | def make_ebola_network(n_nodes, p_in, p_out, seed=None): 12 | """ 13 | Build the EBOLA network with `n_nodes` based on the network of connected 14 | districts. Each district is mapped into a cluster of size proportional to 15 | the population of the district. 16 | 17 | Arguments: 18 | ========== 19 | n_nodes : int 20 | Desired number of nodes. Note: the resulting graph may have one node 21 | more or less than this number due to clique size approximation. 22 | p_in : float 23 | Intra-cluster edge probability 24 | p_out : dict 25 | Inter-country edge probability. It is a dict of float keyed by country for the 26 | inter-cluster edge probability between clusters inside a country, with an extra key 27 | 'inter-country' for the probability of inter-cluster edge probability in different 28 | countries. 
29 | 30 | Return: 31 | ======= 32 | graph : networkx.Graph 33 | Undirected propagation network 34 | """ 35 | # Load base graph 36 | with open(EBOLA_BASE_GRAPH_FILE, 'r') as f: 37 | base_graph_data = json.load(f) 38 | base_graph = nx.readwrite.json_graph.node_link_graph(base_graph_data) 39 | # Add inter-cluster edge probabilities 40 | for u, v, d in base_graph.edges(data=True): 41 | # If same country 42 | if base_graph.node[u]['country'] == base_graph.node[v]['country']: 43 | d['weight'] = p_out[base_graph.node[u]['country']] 44 | # If different country 45 | else: 46 | d['weight'] = p_out['inter-country'] 47 | # Add intra-cluster edge-probability 48 | for u in base_graph.nodes(): 49 | base_graph.add_edge(u, u, weight=p_in) 50 | # Replicate the base graph attributes to each cluster 51 | cluster_names = list(base_graph.nodes()) 52 | country_names = [d['country'] for n, d in base_graph.nodes(data=True)] 53 | cluster_sizes = [int(np.ceil(n_nodes * base_graph.node[u]['size'])) for u in cluster_names] 54 | nodes_district_name = np.repeat(cluster_names, cluster_sizes) 55 | nodes_country_name = np.repeat(country_names, cluster_sizes) 56 | n_nodes = sum(cluster_sizes) 57 | # Build the intra/inter cluster probability matrix 58 | base_adj = nx.adjacency_matrix(base_graph, weight='weight').toarray().astype(float) 59 | # Generate stoch block model graph 60 | graph = nx.generators.stochastic_block_model(cluster_sizes, base_adj, seed=seed) 61 | # Assign district attribute to each node 62 | for u, district, country in zip(graph.nodes(), nodes_district_name, nodes_country_name): 63 | graph.node[u]['district'] = district 64 | graph.node[u]['country'] = country 65 | # Sanity check for name assignment of each cluster 66 | num_unique_block_district = len(set([(node_data['block'], node_data['district']) for u, node_data in graph.nodes(data=True)])) 67 | assert num_unique_block_district == len(cluster_names) 68 | # Extract the giant component 69 | graph = 
"""
Launch one simulation process per (parameter file, network index, simulation
index) triple, running at most `--pool` processes at a time.

Usage:
    python script_many_jobs.py -d <experiment_dir> -n <n_sims> [-p <pool_size>]
"""
from multiprocessing import Process, cpu_count
import argparse
import json
import sys
import os

import script_single_job

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--dir', dest='dir', type=str,
                        required=True, help="Experiment directory")
    parser.add_argument('-n', '--n_sims', dest='n_sims', type=int,
                        required=True, help="Number of simulations per graph and per index")
    parser.add_argument('-p', '--pool', dest='n_workers', type=int,
                        required=False, default=cpu_count() - 1,
                        help="Size of the parallel pool")
    args = parser.parse_args()

    # Init the list of arguments for the workers
    pool_args = list()

    # Extract the list of parameter files in the experiment directory
    param_file_list = sorted([f for f in os.listdir(args.dir) if f.startswith('param')])

    # Make pool args
    for param_filename in param_file_list:

        # Load parameters from file
        param_filename_full = os.path.join(args.dir, param_filename)
        if not os.path.exists(param_filename_full):
            raise FileNotFoundError('Input file `{:s}` not found.'.format(param_filename_full))
        with open(param_filename_full, 'r') as param_file:
            param_dict = json.load(param_file)

        for net_idx in range(len(param_dict['network']['seed_list'])):
            for sim_idx in range(args.n_sims):
                # Extract suffix from param_filename.
                # BUG FIX: `rstrip('.json')` strips any trailing characters
                # from the set {'.', 'j', 's', 'o', 'n'}, mangling names that
                # end with one of those letters; `splitext` drops only the
                # extension.
                suffix = '-'.join(os.path.splitext(param_filename)[0].split('-')[1:])
                # Add network index, zero-padded to two digits.
                # BUG FIX: the spec `{net_idx:2>d}` used '2' as a fill
                # character with no width, so it padded nothing.
                suffix += f'-net{net_idx:02d}'
                # Add simulation index
                suffix += f'-sim{sim_idx:02d}'
                # Build output filename
                output_filename = f'output-{suffix}.json'
                # Redirect stdout
                stdout = os.path.join(args.dir, f"stdout-{suffix:s}")
                # Redirect stderr
                stderr = os.path.join(args.dir, f"stderr-{suffix:s}")
                # Add script arguments to the pool list
                pool_args.append(
                    (args.dir, param_filename, output_filename, net_idx, stdout, stderr)
                )

    print(f"Start {len(pool_args):d} experiments on a pool of {args.n_workers:d} workers")
    print("=============================================================================")

    proc_list = list()

    while len(pool_args) > 0:

        this_args = pool_args.pop()

        print("Start process with parameters:", this_args)

        proc = Process(target=script_single_job.run, args=this_args)
        proc_list.append(proc)
        proc.start()

        # Simple batching: once the pool is full, wait for the entire batch
        # to finish before launching the next one.
        if len(proc_list) == args.n_workers:
            for proc in proc_list:
                proc.join()
            # Reset process list
            proc_list = list()
            print()

    # BUG FIX: join the last (possibly partial) batch so 'Done.' is not
    # printed while workers are still running.
    for proc in proc_list:
        proc.join()

    print('Done.')

"""
import matplotlib.pyplot as plt
import collections
import argparse
import joblib
import os

from analysis import Evaluation, MultipleEvaluations

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', nargs='+', dest='all_selected',
                        help='Experiments to select', required=True)
    parser.add_argument("--switch-backend", dest='plt_backend', type=str,
                        default="agg", help="Switch matplotlib backend")
    args = parser.parse_args()

    # Default backend is "agg" so figures can be rendered without a display.
    plt.switch_backend(args.plt_backend)

    # FIXME: what is this?
    # NOTE(review): when True, the script skips the per-experiment analysis
    # and reloads a previously dumped summary (see the `else` branch below);
    # presumably a debugging shortcut -- confirm before relying on it.
    multi_summary_from_dump = False

    # summary for multi setting comparison
    multi_summary = collections.defaultdict(dict)

    if not multi_summary_from_dump:

        # Analyze each selected experiment
        for expname in args.all_selected:

            print(f"=== Analyzing experiment: {expname:s}")

            # Each pickle holds a list of dicts with keys 'name' and 'dat'
            # (unpacked below), one entry per policy/setting.
            data = joblib.load(os.path.join('temp_pickles', expname + '.pkl'))
            description = [d['name'] for d in data]
            dat = [d['dat'] for d in data]
            evaluation = Evaluation(dat, expname, description)

            # Record this experiment's Qx cost for cross-experiment comparison.
            multi_summary['Qs'][expname] = evaluation.data[0][0]['info']['Qx'][0]

            # Plot the experiment
            print("Make plots of the experiments:")
            evaluation.simulation_plot(
                process='X', filename='simulation_infection_summary',
                granularity=0.1, save=True)
            evaluation.simulation_plot(
                process='H', filename='simulation_treatment_summary',
                granularity=0.1, save=True)

            # evaluation.infections_and_interventions_complete(save=True)
            # evaluation.present_discounted_loss(plot=True, save=True)

            # # Compute Comparison analysis data
            # print("Compute comparison analysis data:")
            # summary_tup = evaluation.summarize_interventions_and_intensities()

            # summary_tup = evaluation.infections_and_interventions_complete(
            #     size_tup=(16, 10), save=True)

            # 
import torch
from collections import defaultdict


class ModelPoissonRenewal:
    """
    Poisson renewal count model with a single reproduction coefficient
    `R = exp(log_r)`.

    The expected count on day `i` is `R * beta * (sum of the counts over the
    preceding window of `T` days)`, and the observed daily counts are modeled
    as Poisson draws around that expectation.
    """

    def set_data(self, count_arr, beta, T):
        """
        Set the observed data.

        Parameters
        ----------
        count_arr : numpy.ndarray
            Array of daily event counts (must support slice `.sum()`).
        beta : float
            Known rate scaling factor.
        T : int or None
            Length of the renewal memory window (in days). If None, the whole
            available history is used.
        """
        # BUG FIX: resolve the window length *before* computing the windowed
        # sums. The original code used the raw `T` argument in the
        # comprehension below, which raised a TypeError whenever `T` was None.
        self.T = len(count_arr) - 1 if T is None else T
        # Counts from day 1 onward (day 0 has no history to explain it)
        self.count = torch.tensor(count_arr, dtype=torch.double)[1:]
        # Windowed cumulative sum of counts over [i - T - 1, i) for each day i
        self.count_cumsum = torch.tensor(
            [count_arr[max(i - self.T - 1, 0):i].sum()
             for i in range(1, len(count_arr))],
            dtype=torch.double)
        self.beta = beta

    def log_likelihood(self, log_r):
        """
        Poisson log-likelihood of the observed counts for log-coefficient
        `log_r`. The small additive constant guards against log(0) when a
        window contains no events.
        """
        lamb = torch.exp(log_r) * self.beta * self.count_cumsum + 1e-10
        vals = self.count * torch.log(lamb) - torch.lgamma(self.count + 1) - lamb
        return vals.sum()

    def objective(self, log_r):
        """Negative log-likelihood (the minimization objective)."""
        return -1.0 * self.log_likelihood(log_r)


class Learner:
    """
    Gradient-based maximum-likelihood fitter using Adam with an exponentially
    decayed learning rate.
    """

    def __init__(self, model, lr, lr_gamma, tol, max_iter):
        """
        Parameters
        ----------
        model : ModelPoissonRenewal
            Model exposing `set_data` and `objective`.
        lr : float
            Initial learning rate.
        lr_gamma : float
            Exponential decay factor of the learning rate.
        tol : float
            Convergence tolerance on the max absolute coefficient change.
        max_iter : int
            Maximum number of gradient iterations.
        """
        self.model = model
        self.lr = lr
        self.lr_gamma = lr_gamma
        self.tol = tol
        self.max_iter = max_iter

    def _set_data(self, count_arr, beta, T):
        # Forward the observed data to the model.
        self.model.set_data(count_arr, beta, T)

    def _check_convergence(self):
        """Converged when no coefficient moved by more than `tol`."""
        if torch.abs(self.coeffs - self.coeffs_prev).max() < self.tol:
            return True
        return False

    def fit(self, x0, daily_count_arr, beta, T=None, callback=None):
        """
        Fit the model coefficients by maximum likelihood.

        Parameters
        ----------
        x0 : torch.Tensor
            Initial coefficient tensor (log reproduction number).
        daily_count_arr : numpy.ndarray
            Array of daily event counts.
        beta : float
            Known rate scaling factor.
        T : int or None, optional
            Renewal window length; None uses the whole history.
        callback : callable, optional
            Called as `callback(self, end=...)` at each iteration and once
            before returning.

        Returns
        -------
        torch.Tensor
            The fitted coefficient tensor (requires grad).

        Raises
        ------
        ValueError
            If NaNs appear in the coefficients or the loss.
        """
        self._set_data(daily_count_arr, beta, T)

        self.coeffs = x0.clone().detach().requires_grad_(True)
        self.coeffs_prev = self.coeffs.detach().clone()

        self.optimizer = torch.optim.Adam([self.coeffs], lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(
            self.optimizer, gamma=self.lr_gamma)

        for t in range(self.max_iter):
            self._n_iter_done = t
            # Gradient step
            self.optimizer.zero_grad()
            self.loss = self.model.objective(self.coeffs)
            self.loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            if torch.isnan(self.coeffs).any():
                raise ValueError('NaNs in coeffs! Stop optimization...')
            if torch.isnan(self.loss).any():
                raise ValueError('NaNs is loss! Stop optimization...')

            # Convergence check
            if self._check_convergence():
                break
            elif callback:  # Callback at each iteration
                callback(self, end='')
            self.coeffs_prev = self.coeffs.detach().clone()
        if callback:  # Callback before the end
            callback(self, end='\n', force=True)
        return self.coeffs


class CallbackMonitor:
    """Progress printer for `Learner.fit`, throttled to every `print_every`
    iterations (a forced call always prints)."""

    def __init__(self, print_every=10):
        self.print_every = print_every

    def __call__(self, learner_obj, end='', force=False):
        t = learner_obj._n_iter_done + 1
        if force or (t % self.print_every == 0):
            dx = torch.abs(learner_obj.coeffs - learner_obj.coeffs_prev).max()
            print("\r    "
                  f"iter: {t:>4d}/{learner_obj.max_iter:>4d} | "
                  f"R: {learner_obj.coeffs[0]:.4f} | "
                  f"loss: {learner_obj.loss:.4f} | "
                  f"dx: {dx:.2e}"
                  "    ", end=end, flush=True)
def build_filename(output_dir, exp):
    """
    Construct the first available (not yet existing) output path for the
    experiment `exp`, indexed by a version suffix _v0, _v1, _v2, ...
    """
    base = os.path.join(
        output_dir,
        f"{exp.name:s}_"
        f"Q_{exp.cost_dict['Qlam']:.0f}_"
        f"{exp.cost_dict['Qx']:.0f}")
    version = 0
    candidate = f"{base}_v{version}.pkl"
    while os.path.exists(candidate):
        version += 1
        candidate = f"{base}_v{version}.pkl"
    return candidate
import bisect


class StochasticProcess:
    """
    General class that handles values of stochastic processes over time.

    Attributes
    ----------
    arrival_times : list
        List of arrival times of the process
    last_arrival : float
        Time of the last arrival (-1.0 until the first arrival)
    values : list
        List of values taken by the process; `values[0]` is the initial
        value, so `len(values) == len(arrival_times) + 1`
    """

    def __init__(self, init_value=0.0):
        self.last_arrival = -1.0
        self.arrival_times = []
        self.values = [init_value]

    def get_last_arrival_time(self):
        """
        Return the time of the last arrival; None if no arrival happened yet.
        """
        return self.last_arrival if self.last_arrival >= 0.0 else None

    def get_current_value(self):
        """Return current value of stochastic process."""
        return self.values[-1]

    def get_arrival_times(self):
        """Return the list of arrival times."""
        return self.arrival_times

    def generate_arrival_at(self, t, dN=1.0, N=None):
        """
        Generate an arrival at time `t`, incrementing the value by `dN`, or
        setting it to `N` if given.

        Raises
        ------
        ValueError
            If `t` is earlier than the last recorded arrival.
        """
        # Check that arrival time happens in the present
        if t < self.last_arrival:
            raise ValueError((f"The provided arrival time t=`{t}` is prior to "
                              f"the last recorded arrival time "
                              f"`{self.last_arrival}`"))
        self.last_arrival = t
        self.arrival_times.append(t)
        if N is None:
            self.values.append(self.values[-1] + dN)
        else:
            self.values.append(N)

    def value_at(self, t):
        """Return the value of the stochastic process at time `t`."""
        # `bisect_right` counts the arrivals at or before `t`; since `values`
        # carries one extra leading entry (the initial value), that count is
        # exactly the index of the value in force at time `t`.
        j = bisect.bisect_right(self.arrival_times, t)
        return self.values[j]


class CountingProcess(StochasticProcess):
    """
    General class that handles values of a counting processes over time.
    A counting process is a particular type of stochastic process that is
    initialized at zero (i.e. `N(0)=0`) and has unit increments (i.e. `dN=1`).
    """

    def __init__(self):
        super().__init__(init_value=0.0)

    def generate_arrival_at(self, t):
        """
        Generate the counting process arrival
        """
        super().generate_arrival_at(t, 1.0)


if __name__ == '__main__':
    """Do some basic unit testing."""

    s = StochasticProcess(init_value=0.0)

    s.generate_arrival_at(1.0, 1.0)
    s.generate_arrival_at(5.0, 1.0)
    s.generate_arrival_at(10.0, -1.0)

    assert(s.get_arrival_times() == [1.0, 5.0, 10.0])
    assert(s.get_current_value() == 1.0)
    assert(s.get_last_arrival_time() == 10.0)

    tests = zip([-2.0, -1.0, 0.0, 1.0, 3.0, 5.0, 7.0, 10.0, 11.0],
                [0.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0])
    for t, sol in tests:
        assert(sol == s.value_at(t))

    c = CountingProcess()
    c.generate_arrival_at(1.0)
    c.generate_arrival_at(5.0)

    assert(c.get_arrival_times() == [1.0, 5.0])
    assert(c.get_current_value() == 2.0)
    assert(c.get_last_arrival_time() == 5.0)

    tests2 = zip([-2.0, -1.0, 0.0, 1.0, 3.0, 5.0, 7.0],
                 [0.0, 0.0, 0.0, 1.0, 1.0, 2.0, 2.0])

    # BUG FIX: this loop previously iterated over `tests`, a single-use zip
    # iterator already exhausted by the first loop, so the CountingProcess
    # value checks were silently skipped.
    for t, sol in tests2:
        assert(sol == c.value_at(t))

    print("Unit tests successful.")
from lib import metrics


# 1. Set simulation parameters
# ============================

# Set simulation parameters (simulate the first quarter of 2014)
start_day_str = '2014-01-01'
end_day_str = '2014-04-01'
max_timedelta = pd.to_datetime(end_day_str) - pd.to_datetime(start_day_str)
max_days = max_timedelta.days

# Set SIR infection and recovery rates
# (per-day rates; presumably taken from the Ebola literature -- TODO confirm)
beta = 1 / 15.3
delta = 1 / 11.4
gamma = beta
rho = 0.0

# Set the network parameters.
# Full-size setting (kept for reference):
# n_nodes = 8000
# p_in = 0.01
# p_out = {
#     'Guinea': 0.00215,
#     'Liberia': 0.00300,
#     'Sierra Leone': 0.00315,
#     'inter-country': 0.0019
# }
# Down-scaled setting (10x fewer nodes, 10x larger inter-district edge
# probabilities) used here for faster calibration runs.
n_nodes = 800
p_in = 0.01
p_out = {
    'Guinea': 0.0215,
    'Liberia': 0.0300,
    'Sierra Leone': 0.0315,
    'inter-country': 0.019
}


# Set the control parameters.
DEFAULT_POLICY_PARAMS = {
    # SOC parameters
    'eta': 1.0,  # SOC exponential decay
    'q_x': None,  # SOC infection cost (set per run below)
    'q_lam': 1.0,  # SOC recovery cost
    'lpsolver': 'cvxopt',  # SOC linear program solver

    # Scaling of baseline methods
    'TR': None,
    'MN': None,
    'LN': None,
    'LRSR': None,
    'MCM': None,
    'front-loading': {  # Front-loading parameters (will be set after the SOC run)
        'max_interventions': None,
        'max_lambda': None
    }
}


# 2. Run calibration
# ==================


def worker(policy, policy_params, n_sims):
    """
    Run `n_sims` SIR simulations of `policy` on a freshly sampled Ebola
    network and return, per simulation, the maximum total control intensity
    (`max_u`) and the number of treated nodes (`n_tre`).
    """
    graph = make_ebola_network(n_nodes=n_nodes, p_in=p_in, p_out=p_out)
    print(f'graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges')

    # Seed infections sampled from the historical data up to the start day.
    init_event_list = sample_seeds(graph, delta=delta, method='data',
                                   max_date=start_day_str, verbose=False)

    res_dict = {'max_u': list(), 'n_tre': list()}

    for sim_idx in range(n_sims):
        print(f"\rSim {sim_idx+1}/{n_sims}", end="")

        # NOTE(review): `verbose=True` here while the sibling script
        # (script_baseline_calibration_soc.py) uses `verbose=False`; the
        # simulator output will interleave with the `\r` progress line.
        sir_obj = SimulationSIR(graph, beta=beta, delta=delta, gamma=gamma, rho=rho, verbose=True)
        sir_obj.launch_epidemic(
            init_event_list=init_event_list,
            max_time=max_days,
            policy=policy,
            policy_dict=policy_params
        )

        res_dict['max_u'].append(sir_obj.max_total_control_intensity)
        res_dict['n_tre'].append(sir_obj.is_tre.sum())

    return res_dict


# Sweep over the SOC infection-cost parameter q_x.
q_x_range = [1.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 750.0, 1000.0]
num_nets = 2
num_sims = 2

args_list = list()
for i, q_x in enumerate(q_x_range):
    print(f'=== q_x {i+1} / {len(q_x_range)}')

    for net_idx in range(num_nets):
        print(f'--- Network {net_idx+1} / {num_nets}')

        # Deep copy so each job gets its own, independent parameter dict.
        policy_params = copy.deepcopy(DEFAULT_POLICY_PARAMS)
        policy_params['q_x'] = q_x

        args_list.append(('SOC', policy_params, num_sims))

# Run experiments in parallel
# NOTE(review): this pool is created at module level without an
# `if __name__ == "__main__":` guard; on spawn-based platforms this
# re-executes the module in every worker -- verify before reuse.
pool = Pool(cpu_count()-1)
res_list = pool.starmap(worker, args_list)

# Save results
res_df = pd.DataFrame(res_list)
res_df.to_csv('baseline-calibration-soc.csv', index=False)
class HelperFunc:
    """
    Collection of helpers for evaluating arrays of `StochasticProcess`
    objects over time (each object exposes `get_arrival_times()` and
    `value_at(t)`).
    """

    def __init__(self):
        pass

    def all_arrivals(self, sps):
        """
        Extract all the arrival times from the stochastic processes in `sps`.

        Parameters
        ----------
        sps : numpy.ndarray
            Array of `StochasticProcess`

        Returns
        -------
        arrival_list : list
            Sorted flat list containing all the arrivals in `sps`
        """
        merged = []
        for proc in sps:
            merged.extend(proc.get_arrival_times())
        return sorted(merged)

    def sps_values(self, sps, t, summed=False):
        """
        Return the value of all stochastic processes in `sps` at time `t`.

        Parameters
        ----------
        sps : numpy.ndarray
            Array of `StochasticProcess`
        t : float
            Evaluation time
        summed : bool, optional (default: False)
            If True, return the sum of all process values at time `t`
            instead of the per-process array

        Returns
        -------
        numpy.ndarray or scalar
            Per-process values at time `t`, or their sum if `summed` is True
        """
        evaluate = np.vectorize(lambda proc: proc.value_at(t))
        values = evaluate(sps)
        return np.add.reduce(values) if summed else values

    def sps_values_over_time(self, sps, summed=False):
        """
        Return the value of all stochastic processes evaluated at the union
        of all arrival times.

        Parameters
        ----------
        sps : numpy.ndarray
            Array of `StochasticProcess`
        summed : bool, optional (default: False)
            If True, each entry is the sum over processes rather than the
            per-process array

        Returns
        -------
        sps_list : list
            One entry per arrival time in `all_arrivals(sps)`
        """
        return [self.sps_values(sps, t, summed=summed)
                for t in self.all_arrivals(sps)]

    def step_sps_values_over_time(self, sps, summed=False):
        """
        Compute `t = all_arrivals(sps)` and
        `X = sps_values_over_time(sps, summed=summed)` augmented to cover the
        constant intervals: every two consecutive pairs (t1, t2), (t3, t4),
        ... in the returned `t` share a constant value in `X`, i.e.
        X[t1] == X[t2], X[t3] == X[t4], ... (suitable for step plots).

        Parameters
        ----------
        sps : numpy.ndarray
            Array of `StochasticProcess`
        summed : bool, optional (default: False)
            Forwarded to `sps_values_over_time`

        Returns
        -------
        t : list
            Duplicated arrival times, prefixed with 0.0
        X : list
            Duplicated process values aligned with `t`
        """
        t_raw = self.all_arrivals(sps)
        y_raw = self.sps_values_over_time(sps, summed=summed)
        # Duplicate every entry; drop the final duplicate of t and prepend
        # 0.0 so each flat segment is described by a pair of points.
        t_doubled = [v for pair in zip(t_raw, t_raw) for v in pair]
        t = [0.0] + t_doubled[:-1] if t_raw else []
        y = [v for pair in zip(y_raw, y_raw) for v in pair]
        return t, y
29 | "metadata": {}, 30 | "source": [ 31 | "Predefined network settings" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 15, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "n_nodes = 8000 # Desired number of nodes (we only keep the giant component, so actual number of nodes may be smaller)\n", 41 | "p_in = 0.01 # Intra-district edge probability\n", 42 | "p_out = { # Inter-district edge probability\n", 43 | " 'Guinea': 0.00215,\n", 44 | " 'Liberia': 0.00300, \n", 45 | " 'Sierra Leone': 0.00315, \n", 46 | " 'inter-country': 0.0019\n", 47 | "}\n", 48 | "\n", 49 | "# Generate scaled graph with settings\n", 50 | "G = make_ebola_network(n_nodes=n_nodes, p_in=p_in, p_out=p_out, seed=None)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "Run simulation on graph with predefined settings" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 17, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stdout", 67 | "output_type": "stream", 68 | "text": [ 69 | "Add seed 7722 from district WESTERN - inf: 0.0, rec: 3.7409395861956174 \n", 70 | "Add seed 6407 from district PUJEHUN - inf: 0.0, rec: 4.430468186283136 \n", 71 | "Add seed 7287 from district TONKOLILI - inf: 0.0, rec: 16.144704236934587 \n", 72 | "Add seed 7901 from district WESTERN - inf: 0.0, rec: 8.80503599378464 \n", 73 | "Add seed 3574 from district GUECKEDOU - inf: 0.0, rec: 13.951449265672295 \n", 74 | "Add seed 3515 from district GUECKEDOU - inf: 0.0, rec: 6.234714560266315 \n" 75 | ] 76 | } 77 | ], 78 | "source": [ 79 | "# Sample seeds from the dataset (random set from infections seen before the given date, taking into account possible random recovery)\n", 80 | "init_event_list = sample_seeds(graph=G, delta=1.0 / 11.4, method='data', max_date='2014-01-01')\n", 81 | "\n", 82 | "# Initialize object\n", 83 | "sir_obj = SimulationSIR(G, # Graph of individuals\n", 84 | " beta=1.0 / 15.3, # Infection rate 
(advised by literature)\n", 85 | " delta=1.0 / 11.4, # Recovery rate (advised by literature)\n", 86 | " gamma=0.0, rho=0.0, # Treatement values, should remain zero\n", 87 | " verbose=False)\n", 88 | "\n", 89 | "# Run the sumulation\n", 90 | "sir_obj.launch_epidemic(init_event_list=init_event_list, max_time=90)" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [] 99 | } 100 | ], 101 | "metadata": { 102 | "kernelspec": { 103 | "display_name": "Python 3", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.7.11" 118 | } 119 | }, 120 | "nbformat": 4, 121 | "nbformat_minor": 4 122 | } 123 | -------------------------------------------------------------------------------- /lib/ogata/visualize_OPT.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from pprint import pprint 4 | import matplotlib.pyplot as plt 5 | import scipy.stats 6 | from tqdm import tqdm 7 | import scipy.optimize 8 | import joblib 9 | import networkx as nx 10 | 11 | from dynamics import SISDynamicalSystem 12 | from analysis import Evaluation 13 | from stochastic_processes import StochasticProcess, CountingProcess 14 | 15 | from helpers import HelperFunc 16 | 17 | 18 | '''Construct network from input''' 19 | df = pd.read_csv("data/contiguous-usa.txt", sep=" ", header=None) 20 | df = df - 1 21 | df = df.drop(columns=[2]) 22 | df_0, df_1 = df[0].values, df[1].values 23 | N, M = max(np.max(df_0), np.max(df_1)) + 1, len(df_0) 24 | 25 | # load X, u 26 | print('load X, u...') 27 | hf = HelperFunc() 28 | filename = 
'results_comparison_OPT_T_MN_complete_8_50__Q_1_400_.pkl'
data = joblib.load('temp_pickles/' + filename)
print('done.')


# Frame times: each base time appears twice plus a slightly offset copy,
# presumably to render repeated/near-duplicate frames -- TODO confirm intent.
times = np.arange(1.0, 4.0, 0.1).tolist()
times2 = np.arange(1.001, 4.001, 0.1).tolist()
times = times + times + times2
times.sort()

# Index of the simulation trial to visualize.
trial_no = 1

for t in times:

    # Control intensities and infection states of all nodes at time t.
    u_t = hf.sps_values(data[0][trial_no]['u'], t, summed=False)
    X_t = hf.sps_values(data[0][trial_no]['X'], t, summed=False)

    G = nx.from_pandas_edgelist(df.astype('i'), 0, 1, create_using=nx.Graph())

    # And a data frame with characteristics for your nodes
    # X = [0 if i % 4 == 0 else 1 for i in range(N)]
    # u = [i for i in range(N)]
    X = list(X_t)
    u = list(u_t)

    # Infection tags
    # NOTE(review): `carac` is built but never used afterwards in this
    # script -- looks like leftover scaffolding.
    carac = pd.DataFrame({'ID': G.nodes(), 'myvalue': X})
    carac['myvalue'] = pd.Categorical(carac['myvalue'])

    # NOTE(review): the layout is recomputed for every frame; `spring_layout`
    # is randomized, so node positions may jump between frames -- confirm
    # this is intended before producing animations.
    pos = nx.spring_layout(G, k=0.04)


    nodes_to_u = {i: u[i] for i in G.nodes()}
    nodes_to_X = {i: X[i] for i in G.nodes()}

    X_to_edgecolors = ['red' if X[i] == 1.0 else 'black' for i in G.nodes()]
    X_to_linewidths = [2.0 if X[i] == 1.0 else 0.7 for i in G.nodes()]

    # Partition the nodes by infection state and (non-)zero control intensity.
    healthy_nodes = [i for i in G.nodes() if X[i] != 1.0]
    infected_nodes = [i for i in G.nodes() if X[i] == 1.0 and abs(u[i]) == 0.0]
    infected_treated_nodes = [i for i in G.nodes() if X[i] == 1.0 and abs(u[i]) != 0.0]
    # add one infected node with u = 0 to 'under treatment' for proper color mapping
    if infected_nodes:
        v = infected_nodes.pop()
        infected_treated_nodes.append(v)
    infected_treated_nodes_to_u = {i: u[i] for i in infected_treated_nodes}


    plt.figure(figsize=(6, 4))
    nx.draw_networkx_edges(G, pos, nodelist=list(nodes_to_u.keys()), alpha=0.4)
    nx.draw_networkx_nodes(G, pos, nodelist=infected_treated_nodes,
                           node_size=100,
                           node_color=list(infected_treated_nodes_to_u.values()),
                           cmap=plt.cm.Blues,
                           # 
node_color='blue', 84 | linewidths=2.5, 85 | edgecolors='black', 86 | label='infected and targeted for treatment') 87 | nx.draw_networkx_nodes(G, pos, nodelist=infected_nodes, 88 | node_size=100, 89 | node_color='white', 90 | # cmap=plt.cm.Blues, 91 | linewidths=2.5, 92 | edgecolors='black', 93 | label='infected') 94 | nx.draw_networkx_nodes(G, pos, nodelist=healthy_nodes, 95 | node_size=100, 96 | node_color='white', 97 | # cmap=plt.cm.Blues, 98 | linewidths=0.7, 99 | edgecolors='black', 100 | label='healthy') 101 | plt.axis('off') 102 | plt.legend(numpoints=1) 103 | plt.savefig('graphs/network_visualization_3_{}.png'.format(int(t * 1000)), frameon=False, format='png', dpi=600) 104 | # plt.show() 105 | plt.close('all') 106 | 107 | -------------------------------------------------------------------------------- /lib/maxcut.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Implement the MaxCutMinimization (MCM) strategy to suppress SIS epidemics, as 4 | defined in the following paper. 5 | 6 | K. Scaman, A. Kalogeratos and N. Vayatis, "Suppressing Epidemics in Networks 7 | Using Priority Planning," in IEEE Transactions on Network Science and 8 | Engineering, vol. 3, no. 4, pp. 271-285, Oct.-Dec. 2016. 9 | 10 | """ 11 | import networkx as nx 12 | import numpy as np 13 | import scipy as sp 14 | import math 15 | import time 16 | 17 | 18 | def _spectral_sequencing(adjacency_mat): 19 | """ 20 | Order the nodes according to the eigenvector related to the second 21 | smallest eigenvalue of the Laplacian matrix. 
def _swap_heuristic(adjacency_mat, plist, seed, n_swaps):
    """
    Randomized heuristic that starts from the priority list `plist` and tries
    to improve the maxcut. At each iteration, it applies one random swap and
    keeps the modification if this swap results in a lower sum of cuts.

    Parameters
    ----------
    adjacency_mat : numpy.ndarray
        Symmetric adjacency matrix of the network.
    plist : numpy.ndarray
        Initial priority ordering of the nodes (modified in place).
    seed : int or None
        Seed of the random number generator; None leaves the RNG untouched.
    n_swaps : int
        Number of random swaps to try.
    """
    # `cut_list` accepts this graph directly (nx.Graph(G) copies it), so the
    # expensive matrix-to-graph conversion is done only once here.
    G = nx.Graph(adjacency_mat)
    n_nodes = len(adjacency_mat)
    # Initialize the random seed.
    # BUG FIX: `if seed:` silently ignored the valid seed value 0.
    if seed is not None:
        np.random.seed(seed)
    # Initialize the current sum of cuts
    curr_sum_cuts = sum(cut_list(G, plist))
    last_print = time.time()
    for i in range(n_swaps):
        # Throttled progress printing (at most ~10 updates per second)
        if (time.time() - last_print > 0.1):
            last_print = time.time()
            done = 100 * i / n_swaps
            print('\r', f'Computing MCM heuristic... {done:.2f}%',
                  sep='', end='', flush=True)
        # Sample two nodes randomly (possibly equal, in which case the swap
        # is a no-op)
        x, y = np.random.randint(0, n_nodes, size=2)
        # Swap their order
        plist[x], plist[y] = plist[y], plist[x]
        # Compute the new sum of cuts
        new_sum_cuts = sum(cut_list(G, plist))
        # If improvement, keep the swap and update the sum of cuts; else undo
        if new_sum_cuts < curr_sum_cuts:
            curr_sum_cuts = new_sum_cuts
        else:
            plist[x], plist[y] = plist[y], plist[x]
    return plist
def cut_list(adjacency_mat, plist):
    """
    Compute the cuts of the priority list `plist`. Return an array whose
    element in position `i` is the cut between nodes `i` and `i+1` in the
    priority list.
    """
    n_nodes = len(adjacency_mat)
    graph = nx.Graph(adjacency_mat)
    cuts = np.zeros(n_nodes, dtype='int')
    seen = set()
    # Number of edges crossing the frontier between processed and pending
    # nodes; each neighbor of the current node either joins (+1) or leaves
    # (-1) the crossing set.
    crossing = 0
    for idx, node in enumerate(plist):
        for neighbor in graph.neighbors(node):
            crossing += -1 if neighbor in seen else 1
        cuts[idx] = crossing
        seen.add(node)
    return cuts
# ============================

# Set simulation time window
start_day_str = '2014-01-01'
end_day_str = '2014-04-01'
max_timedelta = pd.to_datetime(end_day_str) - pd.to_datetime(start_day_str)
max_days = max_timedelta.days

# Set SIR infection and recovery rates
beta = 1 / 15.3
delta = 1 / 11.4
gamma = beta
rho = 0.0

# Set the network parameters.
n_nodes = 8000
p_in = 0.01
p_out = {
    'Guinea': 0.00215,
    'Liberia': 0.00300,
    'Sierra Leone': 0.00315,
    'inter-country': 0.0019
}

# Set the control parameters.
DEFAULT_POLICY_PARAMS = {
    # SOC parameters
    'eta': 1.0,           # SOC exponential decay
    'q_x': None,          # SOC infection cost (set per job in __main__)
    'q_lam': 1.0,         # SOC recovery cost
    'lpsolver': 'cvxopt', # SOC linear program solver

    # Scaling of baseline methods
    'TR': None,
    'MN': None,
    'LN': None,
    'LRSR': None,
    'MCM': None,
    'front-loading': {  # Front-loading parameters (set after the SOC run)
        'max_interventions': None,
        'max_lambda': None
    }
}


def worker(policy, policy_params, n_sims, q_idx, net_idx, output_filename):
    """
    Run multiple simulations of a given policy and save the maximum sum of
    control rates observed for each simulation.

    Parameters
    ----------
    policy : str
        Name of the policy to simulate (e.g. 'SOC').
    policy_params : dict
        Policy parameters (see `DEFAULT_POLICY_PARAMS`).
    n_sims : int
        Number of simulations to run on the sampled network.
    q_idx : int
        Index of the infection-cost value within `Q_X_RANGE`.
    net_idx : int
        Index of the sampled network realization.
    output_filename : str
        Path of the JSON file where the results are written.
    """
    graph = make_ebola_network(n_nodes=n_nodes, p_in=p_in, p_out=p_out)
    print(f'graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges')

    init_event_list = sample_seeds(graph, delta=delta, method='data',
                                   max_date=start_day_str, verbose=False)

    res_dict = {
        'max_u': list(),
        'n_tre': list(),
        'q_x': policy_params['q_x'],
        'q_idx': q_idx,
        'net_idx': net_idx,
    }

    for sim_idx in range(n_sims):

        start_time = time.time()

        sir_obj = SimulationSIR(graph, beta=beta, delta=delta, gamma=gamma, rho=rho,
                                verbose=False)
        sir_obj.launch_epidemic(
            init_event_list=init_event_list,
            max_time=max_days,
            policy=policy,
            policy_dict=policy_params
        )

        res_dict['max_u'].append(float(sir_obj.max_total_control_intensity))
        res_dict['n_tre'].append(float(sir_obj.is_tre.sum()))

        run_time = time.time() - start_time

        # BUGFIX: the spec must be `.2f`; the original `:2.f` raised
        # `ValueError: Format specifier missing precision` on every run,
        # crashing the worker before the results were saved.
        print(f"Finished: q_x:{q_idx} ({policy_params['q_x']}) "
              f"net:{net_idx+1} sim:{sim_idx+1}/{n_sims} in {run_time:.2f} sec")

    with open(output_filename, 'w') as f:
        json.dump(res_dict, f)


if __name__ == "__main__":

    OUT_DIR = os.path.join(PROJECT_DIR, 'output', 'baseline-calibration-soc')
    if not os.path.exists(OUT_DIR):
        print(f"Create output directory: {OUT_DIR}")
        # `makedirs` also creates the intermediate `output/` directory,
        # on which a bare `os.mkdir` would fail.
        os.makedirs(OUT_DIR, exist_ok=True)

    Q_X_RANGE = [1.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 750.0, 1000.0]

    NUM_NETS = 5
    NUM_SIMS = 5

    # Build one job per (infection cost, network realization) pair
    args_list = list()
    for q_idx, q_x in enumerate(Q_X_RANGE):
        for net_idx in range(NUM_NETS):

            policy_params = copy.deepcopy(DEFAULT_POLICY_PARAMS)
            policy_params['q_x'] = q_x

            output_filename = os.path.join(OUT_DIR, f"output-q{q_idx:d}-n{net_idx:d}.json")

            args_list.append(('SOC', policy_params, NUM_SIMS, q_idx, net_idx, output_filename))

    n_procs = cpu_count() - 1

    print(f"\nRun {len(args_list)} jobs on {n_procs} processes...\n")

    # Context manager ensures the worker processes are terminated and joined.
    with Pool(n_procs) as pool:
        pool.starmap(worker, args_list)
'../' not in sys.path:\n", 16 | " sys.path.append('../')" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": { 23 | "scrolled": true 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "import networkx as nx\n", 28 | "import numpy as np\n", 29 | "import argparse\n", 30 | "import joblib\n", 31 | "import os\n", 32 | "\n", 33 | "from experiment import Experiment" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 3, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "Network: 500 nodes, 10761 edges\n", 46 | "\n", 47 | "Choose set of initial infected seeds\n", 48 | "\n" 49 | ] 50 | } 51 | ], 52 | "source": [ 53 | "# Construct the adjacency matrix A of the propagation network\n", 54 | "net = nx.read_edgelist('../data/ebola/ebola_augmented_net_edge_list.csv')\n", 55 | "A = nx.adjacency_matrix(net).toarray().astype(float)\n", 56 | "n_nodes = net.number_of_nodes()\n", 57 | "n_edges = net.number_of_edges()\n", 58 | "print(f\"Network: {n_nodes:d} nodes, {n_edges:d} edges\")\n", 59 | "print()\n", 60 | "\n", 61 | "# Initial infections\n", 62 | "print('Choose set of initial infected seeds')\n", 63 | "infected = 10\n", 64 | "X_init = np.hstack(((np.ones(infected), np.zeros(n_nodes - infected))))\n", 65 | "X_init = np.random.permutation(X_init)\n", 66 | "print()" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "exp = Experiment(\n", 76 | " name='test_all',\n", 77 | " sim_dict={\n", 78 | " 'total_time': 10.00,\n", 79 | " 'trials_per_setting': 1\n", 80 | " },\n", 81 | " param_dict={\n", 82 | " 'beta': 6.0,\n", 83 | " 'gamma': 5.0,\n", 84 | " 'delta': 1.0,\n", 85 | " 'rho': 5.0,\n", 86 | " 'eta': 1.0\n", 87 | " },\n", 88 | " cost_dict={\n", 89 | " 'Qlam': 1.0,\n", 90 | " 'Qx': 400.0\n", 91 | " },\n", 92 | " policy_list=[\n", 93 | " 'SOC',\n", 94 | " ],\n", 95 | " 
baselines_dict={\n", 96 | " 'TR': 0.003,\n", 97 | " 'MN': 0.0007,\n", 98 | " 'LN': 0.0008,\n", 99 | " 'LRSR': 22.807,\n", 100 | " 'MCM': 22.807,\n", 101 | " 'FL_info': {'N': None, 'max_u': None},\n", 102 | " }\n", 103 | ")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 5, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "from dynamics import SISDynamicalSystem" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 6, 118 | "metadata": { 119 | "scrolled": true 120 | }, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "=== Policy: SOC...\n", 127 | " - Trial 1/1\n" 128 | ] 129 | }, 130 | { 131 | "name": "stderr", 132 | "output_type": "stream", 133 | "text": [ 134 | " " 135 | ] 136 | }, 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "*** KeyboardInterrupt exception caught in code being profiled." 142 | ] 143 | }, 144 | { 145 | "name": "stderr", 146 | "output_type": "stream", 147 | "text": [ 148 | "\r" 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "%lprun -f SISDynamicalSystem._getOptPolicy exp.run(A, X_init)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [] 162 | } 163 | ], 164 | "metadata": { 165 | "kernelspec": { 166 | "display_name": "Python 3", 167 | "language": "python", 168 | "name": "python3" 169 | }, 170 | "language_info": { 171 | "codemirror_mode": { 172 | "name": "ipython", 173 | "version": 3 174 | }, 175 | "file_extension": ".py", 176 | "mimetype": "text/x-python", 177 | "name": "python", 178 | "nbconvert_exporter": "python", 179 | "pygments_lexer": "ipython3", 180 | "version": "3.6.3" 181 | } 182 | }, 183 | "nbformat": 4, 184 | "nbformat_minor": 2 185 | } 186 | -------------------------------------------------------------------------------- /notebooks/0-tutorial/1-SIR-simulation.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SIR Simulation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import sys\n", 17 | "if '../../' not in sys.path:\n", 18 | " sys.path.append('../../')\n", 19 | "\n", 20 | "import networkx as nx\n", 21 | " \n", 22 | "from lib.dynamics import SimulationSIR" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "G = nx.complete_graph(3)\n", 32 | "\n", 33 | "init_event_list = [((0, 'inf', 0), 0.0), ]\n", 34 | "\n", 35 | "beta = 1.0\n", 36 | "delta = 0.1\n", 37 | "\n", 38 | "max_days = 20" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "sir_obj = SimulationSIR(G, beta=beta, delta=delta, gamma=0.0, rho=0.0, verbose=False)\n", 48 | "sir_obj.launch_epidemic(init_event_list=init_event_list, max_time=max_days)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "Have nodes been infected ?" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 4, 61 | "metadata": {}, 62 | "outputs": [ 63 | { 64 | "data": { 65 | "text/plain": [ 66 | "array([ True, True, True])" 67 | ] 68 | }, 69 | "execution_count": 4, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | } 73 | ], 74 | "source": [ 75 | "sir_obj.is_inf" 76 | ] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": {}, 81 | "source": [ 82 | "Have nodes been recovered ?" 
83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 5, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "text/plain": [ 93 | "array([False, True, False])" 94 | ] 95 | }, 96 | "execution_count": 5, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "sir_obj.is_rec" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "When were nodes infected ?" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 6, 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "data": { 119 | "text/plain": [ 120 | "array([0. , 1.34044955, 0.64087974])" 121 | ] 122 | }, 123 | "execution_count": 6, 124 | "metadata": {}, 125 | "output_type": "execute_result" 126 | } 127 | ], 128 | "source": [ 129 | "sir_obj.inf_occured_at" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "metadata": {}, 135 | "source": [ 136 | "When were nodes recovered ?" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "data": { 146 | "text/plain": [ 147 | "array([ inf, 2.06541615, inf])" 148 | ] 149 | }, 150 | "execution_count": 7, 151 | "metadata": {}, 152 | "output_type": "execute_result" 153 | } 154 | ], 155 | "source": [ 156 | "sir_obj.rec_occured_at" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "Who infected the node (-1 if source)?" 
164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 8, 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "array([-1, 2, 0])" 175 | ] 176 | }, 177 | "execution_count": 8, 178 | "metadata": {}, 179 | "output_type": "execute_result" 180 | } 181 | ], 182 | "source": [ 183 | "sir_obj.infector" 184 | ] 185 | }, 186 | { 187 | "cell_type": "markdown", 188 | "metadata": {}, 189 | "source": [ 190 | "List of seeds" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 9, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "data": { 200 | "text/plain": [ 201 | "array([ True, False, False])" 202 | ] 203 | }, 204 | "execution_count": 9, 205 | "metadata": {}, 206 | "output_type": "execute_result" 207 | } 208 | ], 209 | "source": [ 210 | "sir_obj.initial_seed" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python 3", 224 | "language": "python", 225 | "name": "python3" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.7.11" 238 | } 239 | }, 240 | "nbformat": 4, 241 | "nbformat_minor": 4 242 | } 243 | -------------------------------------------------------------------------------- /lib/ogata/experiment.py: -------------------------------------------------------------------------------- 1 | from dynamics import SISDynamicalSystem 2 | 3 | 4 | class Experiment: 5 | """ 6 | Class for easier handling of experiments. 7 | Directly incorporates running the dynamical system simulation. 

    Attributes:
    -----------
    policy_list : list
        List of policies to run for the experiment. `SOC` should be first so
        that the front-loading parameters of the baselines can be set from
        the SOC run.
    sim_dict : dict
        Simulation parameters. Should have keys:
        - 'total_time': the time to run each simulation
        - 'trials_per_setting': the number of trials for each simulation
    param_dict : dict
        Model parameters. Should have keys:
        - 'beta': infection rate
        - 'gamma': reduction in infection rate
        - 'delta': spontaneous recovery rate
        - 'rho': treatment recovery rate
        - 'eta': exponential decay for the SOC strategy
    cost_dict : dict
        Costs of the SOC strategy. Should have keys:
        - 'Qlam': cost of treatment
        - 'Qx': cost of infection
    baselines_dict : dict
        Parameters for baseline strategies. Should have keys:
        - each policy with its corresponding scaling parameter
        - 'FL_info': a dict with keys:
            - 'N': maximum number of treatments
            - 'max_u': maximum intensity
    """

    def __init__(self, name, policy_list, sim_dict, param_dict, cost_dict, baselines_dict):
        # Default policies to run
        self.policy_list = ['SOC']
        # Default simulation settings
        self.sim_dict = {
            'total_time': 10.00,
            'trials_per_setting': 10
        }
        # Default model parameters
        self.param_dict = {
            'beta': 6.0,
            'gamma': 5.0,
            'delta': 1.0,
            'rho': 5.0,
            'eta': 1.0
        }
        # Default loss function parameterization
        self.cost_dict = {
            'Qlam': 1.0,
            'Qx': 400.0
        }
        # Default proportional scaling for fair comparison
        self.baselines_dict = {
            'TR': 0.003,
            'MN': 0.0007,
            'LN': 0.0008,
            'LRSR': 22.807,
            'MCM': 22.807,
            'FL_info': {'N': None, 'max_u': None},
        }
        # Experiment name
        self.name = name
        # Override the defaults with the given parameters
        self.update(policy_list=policy_list,
                    sim_dict=sim_dict,
                    param_dict=param_dict,
                    cost_dict=cost_dict,
                    baselines_dict=baselines_dict)

    def update(self, policy_list=None, sim_dict=None, param_dict=None,
               cost_dict=None, baselines_dict=None):
        """
        Update the parameter dictionaries. Arguments left to `None` (or
        empty) keep the current values.
        """
        if policy_list:
            assert policy_list[0] == 'SOC', "Strategy `SOC` must be run first to set FL info."
            self.policy_list = policy_list
        if sim_dict:
            self.sim_dict = sim_dict
        if param_dict:
            self.param_dict = param_dict
        if cost_dict:
            self.cost_dict = cost_dict
        if baselines_dict:
            self.baselines_dict = baselines_dict

    def update_fl_info(self, data):
        """
        Update the `FL_info` dict based on the current data. `data` must be
        the data collected from a run of a SISDynamicalSystem under the
        `SOC` policy.
        """
        assert data['info']['policy'] == 'SOC'
        # Extract the maximum value of the control intensity
        max_u = max([max([proc.value_at(t) for t in proc.arrival_times]) for proc in data['u']])
        # Extract the number of treatments
        n_treatement = sum([proc.get_current_value() for proc in data['Nc']])
        # Keep the parameters of the run with the most treatments seen so far
        if self.baselines_dict['FL_info']['N'] is None:
            self.baselines_dict['FL_info']['max_u'] = max_u
            self.baselines_dict['FL_info']['N'] = n_treatement
        elif n_treatement > self.baselines_dict['FL_info']['N']:
            self.baselines_dict['FL_info']['max_u'] = max_u
            self.baselines_dict['FL_info']['N'] = n_treatement

    def run(self, A, X_init):
        """
        Run the experiment and return a summary as a list of dict.

        Parameters
        ----------
        A : numpy.ndarray
            Adjacency matrix of the propagation network.
        X_init : numpy.ndarray
            Initial infection state of each node.

        Returns
        -------
        list of dict
            One entry per policy, with keys 'name' (the policy name) and
            'dat' (the list of per-trial simulation data).
        """
        n_trials = self.sim_dict['trials_per_setting']
        # Initialize the result object
        result = [{"dat": [], "name": policy} for policy in self.policy_list]
        # Simulate every requested policy
        for j, policy in enumerate(self.policy_list):
            print(f"=== Policy: {policy:s}...")
            # ...for many trials
            for tr in range(n_trials):
                print(f" - Trial {tr+1:d}/{n_trials}")
                # Simulate a trajectory of the SIS dynamical system under
                # various control strategies
                system = SISDynamicalSystem(
                    X_init, A, self.param_dict, self.cost_dict)
                data = system.simulate_policy(
                    policy, self.baselines_dict, self.sim_dict, plot=False)
                # Add the policy name to the info dict of the run
                data['info']['policy'] = policy
                # Add the FL_info parameters to the info dict of the run
                data['info']['FL_info'] = self.baselines_dict['FL_info'].copy()
                # Make sure the policy is the correct one and add the data to result
                assert result[j]['name'] == policy
                result[j]['dat'].append(data)
                # If policy is SOC, we update the FL info dict so that the
                # baselines run afterwards can be front-loaded to match it
                if policy == 'SOC':
                    self.update_fl_info(data)
        return result
28 | "\n", 29 | "from analysis import Evaluation, MultipleEvaluations" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# OUTPUT_DIR = '../temp_pickles/'\n", 39 | "# saved = {\n", 40 | "# 0: 'test_all_qx1_Q_1_1_v0.pkl',\n", 41 | "# 1: 'test_all_qx2_Q_1_1_v0.pkl',\n", 42 | "# 2: 'test_all_qx3_Q_1_2_v0.pkl',\n", 43 | "# 3: 'test_all_qx4_Q_1_3_v0.pkl',\n", 44 | "# 4: 'test_all_qx5_Q_1_4_v0.pkl',\n", 45 | "# 5: 'test_all_qx6_Q_1_5_v0.pkl',\n", 46 | "# 6: 'test_all_qx7_Q_1_7_v0.pkl',\n", 47 | "# 7: 'test_all_qx8_Q_1_10_v0.pkl',\n", 48 | "# 8: 'test_all_qx9_Q_1_14_v0.pkl',\n", 49 | "# 9: 'test_all_qx10_Q_1_19_v0.pkl',\n", 50 | "# 10: 'test_all_qx11_Q_1_26_v0.pkl',\n", 51 | "# 11: 'test_all_qx12_Q_1_37_v0.pkl',\n", 52 | "# 12: 'test_all_qx13_Q_1_51_v0.pkl',\n", 53 | "# 13: 'test_all_qx14_Q_1_70_v0.pkl',\n", 54 | "# 14: 'test_all_qx15_Q_1_97_v0.pkl',\n", 55 | "# 15: 'test_all_qx16_Q_1_135_v0.pkl',\n", 56 | "# 16: 'test_all_qx17_Q_1_187_v0.pkl',\n", 57 | "# 17: 'test_all_qx18_Q_1_260_v0.pkl',\n", 58 | "# 18: 'test_all_qx19_Q_1_361_v0.pkl',\n", 59 | "# 19: 'test_all_qx20_Q_1_500_v0.pkl',\n", 60 | "# } \n", 61 | "# all_selected = list(range(20)) # select pickle files to import\n", 62 | "# multi_summary_from_dump = True\n", 63 | "\n", 64 | "\n", 65 | "# OUTPUT_DIR = '../temp_pickles/'\n", 66 | "# saved = {\n", 67 | "# 0: 'test_all_but_MCM_Q_1_300_v0.pkl'\n", 68 | "# }\n", 69 | "# all_selected = [0]\n", 70 | "# multi_summary_from_dump = False\n", 71 | "\n", 72 | "\n", 73 | "OUTPUT_DIR = '../temp_pickles/reprod_workshop_plots/'\n", 74 | "saved = {\n", 75 | " 0: 'test_all_qx0_Q_1_1_v0.pkl',\n", 76 | " 1: 'test_all_qx1_Q_1_10_v0.pkl',\n", 77 | " 2: 'test_all_qx2_Q_1_25_v0.pkl',\n", 78 | " 3: 'test_all_qx3_Q_1_50_v0.pkl',\n", 79 | " 4: 'test_all_qx4_Q_1_75_v0.pkl',\n", 80 | " 5: 'test_all_qx5_Q_1_100_v0.pkl',\n", 81 | " 6: 'test_all_qx6_Q_1_150_v0.pkl',\n", 82 | " 7: 
'test_all_qx7_Q_1_200_v0.pkl',\n", 83 | " 8: 'test_all_qx8_Q_1_300_v0.pkl',\n", 84 | " 9: 'test_all_qx9_Q_1_400_v0.pkl',\n", 85 | " 10: 'test_all_qx10_Q_1_500_v0.pkl',\n", 86 | "} \n", 87 | "\n", 88 | "\n", 89 | "all_selected = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # select pickle files to import\n", 90 | "multi_summary_from_dump = True" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": null, 96 | "metadata": { 97 | "scrolled": false 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "\n", 102 | "# summary for multi setting comparison\n", 103 | "multi_summary = collections.defaultdict(dict)\n", 104 | "if not multi_summary_from_dump:\n", 105 | "\n", 106 | " '''Individual analyses'''\n", 107 | " for selected in all_selected:\n", 108 | "\n", 109 | " # to see graphs instead of saving them, comment out 'plt.switch_backend('agg')' from top of main.py\n", 110 | " # do not comment out when running on cluster\n", 111 | "\n", 112 | " print('Analyzing: {}'.format(saved[selected]))\n", 113 | "\n", 114 | " data = joblib.load(OUTPUT_DIR + saved[selected])\n", 115 | " filename = saved[selected]\n", 116 | " description = [d['name'] for d in data]\n", 117 | " dat = [d['dat'] for d in data]\n", 118 | " eval = Evaluation(dat, filename, description)\n", 119 | "\n", 120 | " multi_summary['Qs'][saved[selected]] = eval.data[0][0]['info']['Qx'][0]\n", 121 | "\n", 122 | " ''''''''''''''''''''''''''''''''''''''''''\n", 123 | "\n", 124 | " '''Individual analysis'''\n", 125 | "\n", 126 | "# eval.simulation_infection_plot(size_tup=(5.0, 3.7), granularity=0.001, save=True)\n", 127 | "\n", 128 | " # eval.infections_and_interventions_complete(save=True)\n", 129 | " # eval.simulation_treatment_plot(granularity=0.001, save=True)\n", 130 | " # eval.present_discounted_loss(plot=True, save=True)\n", 131 | "\n", 132 | " ''''''''''''''''''''''''''''''''''''''''''\n", 133 | "\n", 134 | " '''Compute Comparison analysis data'''\n", 135 | "\n", 136 | " summary_tup = 
eval.infections_and_interventions_complete(size_tup = (8, 5), save=True)\n", 137 | " plt.close()\n", 138 | " multi_summary['infections_and_interventions'][saved[selected]] = summary_tup\n", 139 | "\n", 140 | "# summary_tup = eval.summarize_interventions_and_intensities()\n", 141 | "# multi_summary['stats_intervention_intensities'][saved[selected]] = summary_tup\n", 142 | "\n", 143 | "\n", 144 | " ''''''''''''''''''''''''''''''''''''''''''\n", 145 | " # eval.debug()\n", 146 | "\n", 147 | "\n", 148 | " dum = (saved, all_selected, multi_summary)\n", 149 | " joblib.dump(dum, 'multi_comp_dump_{}'.format(saved[all_selected[-1]]))\n", 150 | "\n", 151 | "else:\n", 152 | "\n", 153 | " dum = joblib.load('multi_comp_dump_{}'.format(saved[all_selected[-1]]))\n", 154 | " saved = dum[0]\n", 155 | " all_selected = dum[1]\n", 156 | " multi_summary = dum[2]\n", 157 | "\n", 158 | "'''Comparative analysis'''\n", 159 | "multi_eval = MultipleEvaluations(saved, all_selected, multi_summary)\n", 160 | "multi_eval.compare_infections(size_tup=(5.0, 3.7), save=True)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": null, 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [] 169 | } 170 | ], 171 | "metadata": { 172 | "kernelspec": { 173 | "display_name": "Python 3", 174 | "language": "python", 175 | "name": "python3" 176 | }, 177 | "language_info": { 178 | "codemirror_mode": { 179 | "name": "ipython", 180 | "version": 3 181 | }, 182 | "file_extension": ".py", 183 | "mimetype": "text/x-python", 184 | "name": "python", 185 | "nbconvert_exporter": "python", 186 | "pygments_lexer": "ipython3", 187 | "version": "3.6.3" 188 | } 189 | }, 190 | "nbformat": 4, 191 | "nbformat_minor": 2 192 | } 193 | -------------------------------------------------------------------------------- /data/ebola/ebola_base_graph.json: -------------------------------------------------------------------------------- 1 | {"directed": false, "multigraph": false, "graph": {}, "nodes": 
[{"size": 0.011869075548713299, "country": "Guinea", "pos": [-8.633333, 8.683333], "id": "BEYLA"}, {"size": 0.021336813687757163, "country": "Guinea", "pos": [-9.311828, 10.382789], "id": "KANKAN"}, {"size": 0.013960081568176071, "country": "Guinea", "pos": [-9.007367, 9.27026], "id": "KEROUANE"}, {"size": 0.011852956301102379, "country": "Guinea", "pos": [-8.533653, 7.802235], "id": "LOLA"}, {"size": 0.02647416617930257, "country": "Guinea", "pos": [-9.472824, 8.538294], "id": "MACENTA"}, {"size": 0.02132983027484561, "country": "Guinea", "pos": [-8.82525, 7.747836], "id": "NZEREKORE"}, {"size": 0.010414995052060626, "country": "Guinea", "pos": [-14.039161, 10.180825], "id": "BOFFA"}, {"size": 0.02518663529743097, "country": "Guinea", "pos": [-14.100133, 11.186467], "id": "BOKE"}, {"size": 0.009578516113723061, "country": "Guinea", "pos": [-13.514774, 9.790735], "id": "DUBREKA"}, {"size": 0.006804140092428818, "country": "Guinea", "pos": [-13.584187, 10.367454], "id": "FRIA"}, {"size": 0.014077460303004547, "country": "Guinea", "pos": [-13.029933, 10.908936], "id": "TELIMELE"}, {"size": 0.08271225999839381, "country": "Guinea", "pos": [-13.578401, 9.641185], "id": "CONAKRY"}, {"size": 0.019514095086245867, "country": "Guinea", "pos": [-13.387612, 9.708636], "id": "COYAH"}, {"size": 0.020030293662557565, "country": "Guinea", "pos": [-13.090435, 9.434471], "id": "FORECARIAH"}, {"size": 0.02644800229668188, "country": "Guinea", "pos": [-12.862989, 10.040672], "id": "KINDIA"}, {"size": 0.008677273195370112, "country": "Guinea", "pos": [-11.110785, 10.729781], "id": "DABOLA"}, {"size": 0.009701108492437561, "country": "Guinea", "pos": [-10.715423, 11.289951], "id": "DINGUIRAYE"}, {"size": 0.010464596416370785, "country": "Guinea", "pos": [-10.749247, 10.045102], "id": "FARANAH"}, {"size": 0.010505683756720278, "country": "Guinea", "pos": [-9.885059, 10.648923], "id": "KOUROUSSA"}, {"size": 0.0101614875970521, "country": "Guinea", "pos": [-12.24907, 10.686818], "id": 
"DALABA"}, {"size": 0.014397979823006835, "country": "Guinea", "pos": [-12.397943, 11.057462], "id": "PITA"}, {"size": 0.008670768098411403, "country": "Guinea", "pos": [-11.664139, 11.446422], "id": "TOUGUE"}, {"size": 0.02310796982955163, "country": "Guinea", "pos": [-9.17883, 11.414811], "id": "SIGUIRI"}, {"size": 0.03835979145998435, "country": "Guinea", "pos": [-10.131116, 8.564969], "id": "GUECKEDOU"}, {"size": 0.01577672555708622, "country": "Guinea", "pos": [-10.114318, 9.191454], "id": "KISSIDOUGO"}, {"size": 0.015573728266698261, "country": "Sierra Leone", "pos": [-11.524805, 9.530862], "id": "KOINADUGU"}, {"size": 0.024131231147475667, "country": "Sierra Leone", "pos": [-12.163272, 9.247584], "id": "BOMBALI"}, {"size": 0.014629245586208008, "country": "Sierra Leone", "pos": [-12.917652, 9.126166], "id": "KAMBIA"}, {"size": 0.01777718636667519, "country": "Liberia", "pos": [-9.723267, 8.191118], "id": "LOFA"}, {"size": 0.01932817367537452, "country": "Sierra Leone", "pos": [-10.571809, 8.28022], "id": "KAILAHUN"}, {"size": 0.023441882196233617, "country": "Sierra Leone", "pos": [-10.89031, 8.766329], "id": "KONO"}, {"size": 0.030186376289342502, "country": "Liberia", "pos": [-8.660059, 6.842761], "id": "NIMBA"}, {"size": 0.01567211785819874, "country": "Guinea", "pos": [-9.259157, 7.569628], "id": "YOMOU"}, {"size": 0.02026538595338151, "country": "Liberia", "pos": [-9.367308, 6.829502], "id": "BONG"}, {"size": 0.006962223514844512, "country": "Liberia", "pos": [-10.845147, 6.756293], "id": "BOMI"}, {"size": 0.005346137404962805, "country": "Liberia", "pos": [-10.08073, 7.495264], "id": "GBARPOLU"}, {"size": 0.007866671150081531, "country": "Liberia", "pos": [-11.071176, 7.046776], "id": "GRAND_CAPE_MOUNT"}, {"size": 0.05617835215658074, "country": "Liberia", "pos": [-10.529611, 6.552581], "id": "MONTSERRADO"}, {"size": 0.012462569983004957, "country": "Liberia", "pos": [-9.812493, 6.230845], "id": "GRAND_BASSA"}, {"size": 0.01646090869603753, "country": 
"Liberia", "pos": [-10.30489, 6.515187], "id": "MARGIBI"}, {"size": 0.031428036671336146, "country": "Sierra Leone", "pos": [-11.195717, 7.863215], "id": "KENEMA"}, {"size": 0.004250219893801422, "country": "Liberia", "pos": [-9.456155, 5.902533], "id": "RIVERCESS"}, {"size": 0.01336012986852099, "country": "Sierra Leone", "pos": [-11.721064, 7.356299], "id": "PUJEHUN"}, {"size": 0.007874467700112926, "country": "Liberia", "pos": [-8.221298, 5.922208], "id": "GRAND_GEDEH"}, {"size": 0.004880066340509396, "country": "Liberia", "pos": [-7.87216, 5.260489], "id": "RIVER_GEE"}, {"size": 0.005959960267250429, "country": "Liberia", "pos": [-8.660059, 5.49871], "id": "SINOE"}, {"size": 0.0035883741098480327, "country": "Liberia", "pos": [-8.221298, 4.761386], "id": "GRAND_KRU"}, {"size": 0.006946486919995869, "country": "Liberia", "pos": [-7.74167, 4.725888], "id": "MARYLAND"}, {"size": 0.027871757561924098, "country": "Sierra Leone", "pos": [-11.471, 7.9552], "id": "BO"}, {"size": 0.007542755586814035, "country": "Sierra Leone", "pos": [-12.503992, 7.525703], "id": "BONTHE"}, {"size": 0.01550135906303262, "country": "Sierra Leone", "pos": [-12.435192, 8.162051], "id": "MOYAMBA"}, {"size": 0.02100601037476868, "country": "Sierra Leone", "pos": [-11.797961, 8.738942], "id": "TONKOLILI"}, {"size": 0.02675158943195341, "country": "Sierra Leone", "pos": [-12.785352, 8.768689], "id": "PORT_LOKO"}, {"size": 0.05398297759620692, "country": "Sierra Leone", "pos": [-13.035694, 8.311498], "id": "WESTERN"}, {"size": 0.011360912680409417, "country": "Guinea", "pos": [-12.297718, 12.074294], "id": "MALI"}], "links": [{"source": "BEYLA", "target": "KANKAN"}, {"source": "BEYLA", "target": "KEROUANE"}, {"source": "BEYLA", "target": "LOLA"}, {"source": "BEYLA", "target": "MACENTA"}, {"source": "BEYLA", "target": "NZEREKORE"}, {"source": "KANKAN", "target": "KEROUANE"}, {"source": "KANKAN", "target": "KISSIDOUGO"}, {"source": "KANKAN", "target": "KOUROUSSA"}, {"source": "KANKAN", "target": 
"SIGUIRI"}, {"source": "KEROUANE", "target": "KISSIDOUGO"}, {"source": "KEROUANE", "target": "MACENTA"}, {"source": "LOLA", "target": "NZEREKORE"}, {"source": "LOLA", "target": "NIMBA"}, {"source": "MACENTA", "target": "GUECKEDOU"}, {"source": "MACENTA", "target": "KISSIDOUGO"}, {"source": "MACENTA", "target": "NZEREKORE"}, {"source": "MACENTA", "target": "YOMOU"}, {"source": "MACENTA", "target": "LOFA"}, {"source": "NZEREKORE", "target": "YOMOU"}, {"source": "NZEREKORE", "target": "NIMBA"}, {"source": "BOFFA", "target": "BOKE"}, {"source": "BOFFA", "target": "DUBREKA"}, {"source": "BOFFA", "target": "FRIA"}, {"source": "BOFFA", "target": "TELIMELE"}, {"source": "BOKE", "target": "TELIMELE"}, {"source": "DUBREKA", "target": "CONAKRY"}, {"source": "DUBREKA", "target": "COYAH"}, {"source": "DUBREKA", "target": "FRIA"}, {"source": "DUBREKA", "target": "KINDIA"}, {"source": "DUBREKA", "target": "TELIMELE"}, {"source": "FRIA", "target": "TELIMELE"}, {"source": "TELIMELE", "target": "KINDIA"}, {"source": "TELIMELE", "target": "PITA"}, {"source": "CONAKRY", "target": "COYAH"}, {"source": "COYAH", "target": "FORECARIAH"}, {"source": "COYAH", "target": "KINDIA"}, {"source": "FORECARIAH", "target": "KINDIA"}, {"source": "FORECARIAH", "target": "BOMBALI"}, {"source": "FORECARIAH", "target": "KAMBIA"}, {"source": "KINDIA", "target": "DALABA"}, {"source": "KINDIA", "target": "PITA"}, {"source": "KINDIA", "target": "BOMBALI"}, {"source": "DABOLA", "target": "DINGUIRAYE"}, {"source": "DABOLA", "target": "FARANAH"}, {"source": "DABOLA", "target": "KOUROUSSA"}, {"source": "DABOLA", "target": "DALABA"}, {"source": "DINGUIRAYE", "target": "KOUROUSSA"}, {"source": "DINGUIRAYE", "target": "SIGUIRI"}, {"source": "DINGUIRAYE", "target": "TOUGUE"}, {"source": "FARANAH", "target": "GUECKEDOU"}, {"source": "FARANAH", "target": "KISSIDOUGO"}, {"source": "FARANAH", "target": "KOUROUSSA"}, {"source": "FARANAH", "target": "KOINADUGU"}, {"source": "FARANAH", "target": "DALABA"}, {"source": 
"KOUROUSSA", "target": "KISSIDOUGO"}, {"source": "KOUROUSSA", "target": "SIGUIRI"}, {"source": "DALABA", "target": "PITA"}, {"source": "DALABA", "target": "TOUGUE"}, {"source": "PITA", "target": "MALI"}, {"source": "TOUGUE", "target": "MALI"}, {"source": "GUECKEDOU", "target": "KISSIDOUGO"}, {"source": "GUECKEDOU", "target": "LOFA"}, {"source": "GUECKEDOU", "target": "KAILAHUN"}, {"source": "GUECKEDOU", "target": "KOINADUGU"}, {"source": "GUECKEDOU", "target": "KONO"}, {"source": "KOINADUGU", "target": "BOMBALI"}, {"source": "KOINADUGU", "target": "KONO"}, {"source": "KOINADUGU", "target": "TONKOLILI"}, {"source": "BOMBALI", "target": "KAMBIA"}, {"source": "BOMBALI", "target": "PORT_LOKO"}, {"source": "BOMBALI", "target": "TONKOLILI"}, {"source": "KAMBIA", "target": "PORT_LOKO"}, {"source": "LOFA", "target": "YOMOU"}, {"source": "LOFA", "target": "BONG"}, {"source": "LOFA", "target": "GBARPOLU"}, {"source": "LOFA", "target": "KAILAHUN"}, {"source": "KAILAHUN", "target": "GBARPOLU"}, {"source": "KAILAHUN", "target": "KENEMA"}, {"source": "KAILAHUN", "target": "KONO"}, {"source": "KONO", "target": "KENEMA"}, {"source": "KONO", "target": "TONKOLILI"}, {"source": "NIMBA", "target": "YOMOU"}, {"source": "NIMBA", "target": "BONG"}, {"source": "NIMBA", "target": "GRAND_BASSA"}, {"source": "NIMBA", "target": "GRAND_GEDEH"}, {"source": "NIMBA", "target": "RIVERCESS"}, {"source": "NIMBA", "target": "SINOE"}, {"source": "YOMOU", "target": "BONG"}, {"source": "BONG", "target": "BOMI"}, {"source": "BONG", "target": "GBARPOLU"}, {"source": "BONG", "target": "GRAND_BASSA"}, {"source": "BONG", "target": "MARGIBI"}, {"source": "BONG", "target": "MONTSERRADO"}, {"source": "BOMI", "target": "GBARPOLU"}, {"source": "BOMI", "target": "GRAND_CAPE_MOUNT"}, {"source": "BOMI", "target": "MONTSERRADO"}, {"source": "GBARPOLU", "target": "GRAND_CAPE_MOUNT"}, {"source": "GBARPOLU", "target": "MONTSERRADO"}, {"source": "GBARPOLU", "target": "KENEMA"}, {"source": "GRAND_CAPE_MOUNT", "target": 
"KENEMA"}, {"source": "GRAND_CAPE_MOUNT", "target": "PUJEHUN"}, {"source": "MONTSERRADO", "target": "MARGIBI"}, {"source": "GRAND_BASSA", "target": "MARGIBI"}, {"source": "GRAND_BASSA", "target": "RIVERCESS"}, {"source": "KENEMA", "target": "BO"}, {"source": "KENEMA", "target": "PUJEHUN"}, {"source": "KENEMA", "target": "TONKOLILI"}, {"source": "RIVERCESS", "target": "SINOE"}, {"source": "RIVERCESS", "target": "GRAND_GEDEH"}, {"source": "PUJEHUN", "target": "BO"}, {"source": "PUJEHUN", "target": "BONTHE"}, {"source": "GRAND_GEDEH", "target": "RIVER_GEE"}, {"source": "GRAND_GEDEH", "target": "SINOE"}, {"source": "RIVER_GEE", "target": "GRAND_KRU"}, {"source": "RIVER_GEE", "target": "MARYLAND"}, {"source": "RIVER_GEE", "target": "SINOE"}, {"source": "SINOE", "target": "GRAND_KRU"}, {"source": "GRAND_KRU", "target": "MARYLAND"}, {"source": "BO", "target": "BONTHE"}, {"source": "BO", "target": "MOYAMBA"}, {"source": "BO", "target": "TONKOLILI"}, {"source": "BONTHE", "target": "MOYAMBA"}, {"source": "MOYAMBA", "target": "PORT_LOKO"}, {"source": "MOYAMBA", "target": "TONKOLILI"}, {"source": "MOYAMBA", "target": "WESTERN"}, {"source": "TONKOLILI", "target": "PORT_LOKO"}, {"source": "PORT_LOKO", "target": "WESTERN"}]} -------------------------------------------------------------------------------- /script_single_job.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import pandas as pd 4 | import networkx as nx 5 | import argparse 6 | import os 7 | import json 8 | import sys 9 | 10 | from lib.graph_generation import make_ebola_network 11 | from lib.dynamics import SimulationSIR, PriorityQueue 12 | from lib.dynamics import sample_seeds 13 | from lib.settings import DATA_DIR 14 | 15 | 16 | def update_fl_info(param_dict, sir_obj): 17 | """ 18 | Update the front-loading parameters dict based on the current simulation to match the maximum 19 | """ 20 | assert sir_obj.policy == 'SOC', 
"`update_fl_info` must be applied on `SOC` policy" 21 | # Extract the maximum value of the control intensity 22 | max_u = sir_obj.max_control_intensity 23 | # Extract the number of treatements 24 | n_treatement = np.sum(sir_obj.is_tre) 25 | # Extract front-loading dict from param_dict 26 | d = param_dict['simulation']['policy_params']['front-loading'] 27 | # Update the FL info dict 28 | if d['max_lambda'] is None: 29 | d['max_lambda'] = max_u 30 | d['max_interventions'] = n_treatement 31 | elif n_treatement > d['max_interventions']: 32 | d['max_lambda'] = max_u 33 | d['max_interventions'] = n_treatement 34 | 35 | 36 | def run(exp_dir, param_filename, output_filename, net_idx, stdout=None, stderr=None, verbose=False): 37 | """ 38 | Run a single SIR simulation based on the parameters in `param_filename` inside directory 39 | `exp_dir` and output a summary into `output_filename`. Uses the random seed at index `net_idx` 40 | to generate a network. 41 | 42 | stdout and stderr can be redirected to files `stdout` and `stderr`. 43 | 44 | The parameter file is supposed to be a json file with the following format: 45 | ``` 46 | { 47 | 48 | network : { (Network model parameters) 49 | n_nodes : int 50 | Number of nodes desired in the network 51 | p_in : float 52 | Intra-district probability 53 | p_out : dict of float 54 | Inter-district probability per country, keyed by country name, with the additional 55 | key 'inter-country' for between-country edges 56 | seed_list : list. List of random seeds for reproducibility 57 | }, 58 | 59 | simulation : { (SIR simulation parameters) 60 | start_day_str : str 61 | Starting day of the simulation, formated as 'YYYY-MM-DD'. Used to sample seeds 62 | from the Ebola dataset. 63 | end_day_str : str 64 | Ending day of the simulation. 
65 | sir_params : { (Parameters of the SIR model) 66 | beta : float 67 | Infection rate 68 | delta : float 69 | Recovery rate (spontaneous) 70 | gamma : float 71 | Reduction of infectivity under treatement 72 | rho : float 73 | Recovery rate under treatement 74 | }, 75 | policy_name : str 76 | Name of the Policy. 77 | policy_params : { (Parameters of the Policy) 78 | (depend on the policy) 79 | } 80 | }, 81 | 82 | job_type : str (optional) 83 | One the predefined job types. By default, perform a standard simulation following all 84 | the given parameters. Other job types are the following: 85 | - 'stop_after_seeds': 86 | Only perform the simulation on the seeds ego-network and stop once all seeds are 87 | recovered or once their neighbors are all infected. This job is performed to assess 88 | the basic reproduction number of the epidemic given the current parameters. 89 | 90 | } 91 | ``` 92 | 93 | The output file contains: 94 | - the intial seed events, 95 | - the infection time of each node, 96 | - the infector of each node, 97 | - the recovery time of each node, 98 | - the district of each node. 
99 | 100 | """ 101 | 102 | if stdout is not None: 103 | sys.stdout = open(stdout, 'w') 104 | if stderr is not None: 105 | sys.stderr = open(stderr, 'w') 106 | 107 | # Load parameters from file 108 | param_filename_full = os.path.join(exp_dir, param_filename) 109 | if not os.path.exists(param_filename_full): 110 | raise FileNotFoundError('Input file `{:s}` not found.'.format(param_filename_full)) 111 | with open(param_filename_full, 'r') as param_file: 112 | param_dict = json.load(param_file) 113 | 114 | print('\nExperiment parameters') 115 | print('=====================') 116 | print(f' exp_dir = {exp_dir:s}') 117 | print(f' param_filename = {param_filename:s}') 118 | print(f'output_filename = {output_filename:s}', flush=True) 119 | 120 | # Init output dict 121 | output_dict = {} 122 | output_dict['simulation_list'] = list() 123 | 124 | # Generate network of districts 125 | # ============================= 126 | 127 | print('\nGENERATE NETWORK') 128 | print('================') 129 | 130 | # Extract random seed from list 131 | net_seed = param_dict['network']['seed_list'][net_idx] 132 | param_dict['network']['seed'] = net_seed 133 | del param_dict['network']['seed_list'] 134 | 135 | print('\nNetwork parameters') 136 | print(f" - n_nodes = {param_dict['network']['n_nodes']:d}") 137 | print(f" - p_in = {param_dict['network']['p_in']:.2e}") 138 | print(f" - p_out = {param_dict['network']['p_out']}") 139 | print(f" - seed = {net_seed:d}") 140 | 141 | graph = make_ebola_network(**param_dict['network']) 142 | 143 | print('\nGraph generated') 144 | print(f" - {graph.number_of_nodes():d} nodes") 145 | print(f" - {graph.number_of_edges():d} edges", flush=True) 146 | 147 | # Run simulation 148 | # ============== 149 | 150 | print('\nSIMULATION') 151 | print('==========') 152 | 153 | start_day_str = param_dict['simulation']['start_day_str'] 154 | end_day_str = param_dict['simulation']['end_day_str'] 155 | sim_timedelta = pd.to_datetime(end_day_str) - 
pd.to_datetime(start_day_str) 156 | max_time = sim_timedelta.days 157 | 158 | init_seed_method = param_dict['simulation']['init_seed_method'] 159 | init_seed_num = param_dict['simulation'].get('n_seeds') 160 | 161 | print('\nSimulation parameters') 162 | print(f' - start day: {start_day_str}') 163 | print(f' - end day: {end_day_str}') 164 | print(f' - number of days to simulate: {max_time}') 165 | print(f" - init seed method: {init_seed_method}") 166 | print(f" - init seed num: {init_seed_num}") 167 | 168 | print('\nEpidemic parameters') 169 | for key, val in param_dict['simulation']['sir_params'].items(): 170 | print(f' - {key:s}: {val:.2e}') 171 | 172 | print(f"\nPolicy name: {param_dict['simulation']['policy_name']:s}") 173 | 174 | print('\nPolicy parameters') 175 | for key, val in param_dict['simulation']['policy_params'].items(): 176 | if key == 'front-loading': 177 | continue 178 | elif isinstance(val, float): 179 | print(f' - {key:s}: {val:.2e}') 180 | else: 181 | print(f' - {key:s}: {val:s}') 182 | 183 | # Sample initial infected seeds at time t=0 184 | delta = param_dict['simulation']['sir_params']['delta'] 185 | 186 | init_event_list = sample_seeds(graph, delta=delta, 187 | method=init_seed_method, 188 | n_seeds=init_seed_num, 189 | max_date=start_day_str, 190 | verbose=verbose) 191 | 192 | # Set default stopping criteria 193 | stop_criteria = None 194 | 195 | # Modify parameters for special job types 196 | if param_dict.get('job_type') == 'stop_after_seeds': 197 | # Extract ego-network of seeds 198 | seed_node_list = np.array(list(set([event[0] for event, _ in init_event_list]))) 199 | seed_neighbs_list = np.hstack([list(graph.neighbors(u)) for u in seed_node_list]) 200 | graph = nx.subgraph(graph, np.hstack((seed_node_list, seed_neighbs_list))) 201 | 202 | # Define stop_criteria 203 | def stop_criteria(sir_obj): 204 | seed_node_indices = np.array([sir_obj.node_to_idx[u] for u in seed_node_list]) 205 | seed_neighbs_indices = 
np.array([sir_obj.node_to_idx[u] for u in seed_neighbs_list]) 206 | all_seeds_rec = np.all(sir_obj.is_rec[seed_node_indices]) 207 | all_neighbors_inf = np.all(sir_obj.is_inf[seed_neighbs_indices]) 208 | return all_seeds_rec or all_neighbors_inf 209 | 210 | print('\nRun a single simulation for each policy:', flush=True) 211 | 212 | # Simulate every requested policy 213 | for j, policy in enumerate(param_dict['simulation']['policy_list']): 214 | print(f"=== Policy: {policy:s}...") 215 | # ...for many trials 216 | for i in range(param_dict['simulation']['num_sims']): 217 | print(f" - Simulation {i+1:d}/{param_dict['simulation']['num_sims']}") 218 | 219 | # Run SIR simulation 220 | # ------------------ 221 | 222 | # Reinitialize random seed for simulation 223 | random.seed(None) 224 | seed = random.randint(0, 2**32-1) 225 | random.seed(seed) 226 | print(f' - Random seed: {seed}') 227 | # Add to output dict 228 | 229 | 230 | sir_obj = SimulationSIR( 231 | graph, **param_dict['simulation']['sir_params'], verbose=verbose) 232 | sir_obj.launch_epidemic( 233 | init_event_list=init_event_list, max_time=max_time, 234 | policy=param_dict['simulation']['policy_name'], 235 | policy_dict=param_dict['simulation']['policy_params'], 236 | stop_criteria=stop_criteria) 237 | 238 | # Updat the front-loading parameters for subsequent simulations 239 | if policy == 'SOC': 240 | update_fl_info(param_dict, sir_obj) 241 | print(' - Updated front-loading parameters:', 242 | param_dict['simulation']['policy_params']['front-loading']) 243 | 244 | # Post-simulation summarization and set output 245 | # -------------------------------------------- 246 | 247 | # Init output dict for this simulation 248 | out_sim_dict = {} 249 | 250 | # Add seed 251 | out_sim_dict['simulation_seed'] = seed 252 | 253 | # Add init_event_list to output dict 254 | # Format init_event_list node names into int to make the object json-able 255 | for i, (e, t) in enumerate(init_event_list): 256 | init_event_list[i] = 
((int(e[0]), e[1], int(e[2]) if e[2] is not None else None), float(t)) 257 | out_sim_dict['init_event_list'] = init_event_list 258 | 259 | # Add other info on the events of each node 260 | out_sim_dict['inf_occured_at'] = sir_obj.inf_occured_at.tolist() 261 | out_sim_dict['rec_occured_at'] = sir_obj.rec_occured_at.tolist() 262 | out_sim_dict['infector'] = sir_obj.infector.tolist() 263 | out_sim_dict['node_idx_pairs'] = [(int(u), u_idx) for u, u_idx in sir_obj.node_to_idx.items()] 264 | 265 | country_list = np.zeros(sir_obj.n_nodes, dtype=object) 266 | for u, d in sir_obj.G.nodes(data=True): 267 | country_list[sir_obj.node_to_idx[u]] = d['country'] 268 | out_sim_dict['country'] = country_list.tolist() 269 | 270 | node_district_arr = np.zeros(sir_obj.n_nodes, dtype='object') 271 | for node, data in sir_obj.G.nodes(data=True): 272 | node_idx = sir_obj.node_to_idx[node] 273 | node_district_arr[node_idx] = data['district'] 274 | out_sim_dict['district'] = node_district_arr.tolist() 275 | 276 | output_dict['simulation_list'].append(out_sim_dict) 277 | 278 | print('\n\nSave results...') 279 | 280 | with open(os.path.join(exp_dir, output_filename), 'w') as output_file: 281 | json.dump(output_dict, output_file) 282 | 283 | # Log that the run is finished 284 | print('\n\nFinished.') 285 | 286 | 287 | if __name__ == "__main__": 288 | parser = argparse.ArgumentParser() 289 | parser.add_argument('-d', '--dir', dest='dir', type=str, 290 | required=True, help="Working directory") 291 | 292 | parser.add_argument('-i', '--netidx', dest='net_idx', type=int, 293 | required=True, help="Network index to use (in parameter file)") 294 | 295 | parser.add_argument('-p', '--params', dest='param_filename', type=str, 296 | required=False, default='params.json', 297 | help="Input parameter file (JSON)") 298 | parser.add_argument('-o', '--outfile', dest='output_filename', type=str, 299 | required=False, default='output.json', 300 | help="Output file (JSON)") 301 | 302 | 
parser.add_argument('-v', '--verbose', dest='verbose', action="store_true", 303 | required=False, default=False, 304 | help="Print behavior") 305 | args = parser.parse_args() 306 | 307 | run( 308 | exp_dir=args.dir, 309 | param_filename=args.param_filename, 310 | output_filename=args.output_filename, 311 | net_idx=args.net_idx, 312 | verbose=args.verbose 313 | ) 314 | -------------------------------------------------------------------------------- /notebooks/2-calibration/2-baseline-calibration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Calibration of baseline control parameters\n", 8 | "\n", 9 | "---" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "%load_ext autoreload\n", 19 | "%autoreload 2" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "from matplotlib import pyplot as plt\n", 29 | "%matplotlib inline\n", 30 | "\n", 31 | "import os\n", 32 | "import json\n", 33 | "import copy\n", 34 | "import itertools\n", 35 | "from collections import Counter, defaultdict\n", 36 | "import pandas as pd\n", 37 | "import networkx as nx\n", 38 | "import numpy as np\n", 39 | "from multiprocessing import cpu_count, Pool\n", 40 | "\n", 41 | "from lib.graph_generation import make_ebola_network\n", 42 | "from lib.dynamics import SimulationSIR, PriorityQueue\n", 43 | "from lib.dynamics import sample_seeds\n", 44 | "from lib.settings import DATA_DIR\n", 45 | "from lib import metrics" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "---" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## 1. 
Set simulation parameters\n", 60 | "\n", 61 | "Set simulation period" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "start_day_str = '2014-01-01'\n", 71 | "end_day_str = '2014-04-01'\n", 72 | "max_timedelta = pd.to_datetime(end_day_str) - pd.to_datetime(start_day_str)\n", 73 | "max_days = max_timedelta.days" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "Set SIR infection and recovery rates" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 4, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "beta = 1 / 15.3\n", 90 | "delta = 1 / 11.4\n", 91 | "\n", 92 | "gamma = beta\n", 93 | "rho = 0.0" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "Set the network parameters." 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 5, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "n_nodes = 8000\n", 110 | "p_in = 0.01\n", 111 | "p_out = {\n", 112 | " 'Guinea': 0.00215,\n", 113 | " 'Liberia': 0.00300, \n", 114 | " 'Sierra Leone': 0.00315, \n", 115 | " 'inter-country': 0.0019\n", 116 | "}" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "metadata": {}, 122 | "source": [ 123 | "Set the control parameters." 
124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "DEFAULT_POLICY_PARAMS = {\n", 133 | " # SOC parameters\n", 134 | " 'eta': 1.0, # SOC exponential decay\n", 135 | " 'q_x': None, # SOC infection cost\n", 136 | " 'q_lam': 1.0, # SOC recovery cost\n", 137 | " 'lpsolver': 'cvxopt', # SOC linear progam solver\n", 138 | " \n", 139 | " # Scaling of baseline methods\n", 140 | " 'TR' : None,\n", 141 | " 'MN' : None,\n", 142 | " 'LN' : None,\n", 143 | " 'LRSR' : None,\n", 144 | " 'MCM' : None,\n", 145 | " 'front-loading': { # Front-loading parameters (will be set for each simulation after the SOC run)\n", 146 | " 'max_interventions': None,\n", 147 | " 'max_lambda': None\n", 148 | " }\n", 149 | "}\n" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "---\n", 157 | "\n", 158 | "## 2. Run calibration\n", 159 | "\n" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "Define the range of $q_x$ to experiment with." 
167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": 7, 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "q_x_range = [1.0, 10.0, 25.0, 50.0, 100.0, 250.0, 500.0, 750.0, 1000.0]" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "Define the worker function to run several simulation on a given network" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": 8, 188 | "metadata": { 189 | "scrolled": true 190 | }, 191 | "outputs": [], 192 | "source": [ 193 | "def worker(policy, policy_params, n_sims):\n", 194 | " graph = make_ebola_network(n_nodes=n_nodes, p_in=p_in, p_out=p_out)\n", 195 | " print(f'graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges')\n", 196 | "\n", 197 | " init_event_list = sample_seeds(graph, delta=delta, method='data', \n", 198 | " max_date=start_day_str, verbose=False)\n", 199 | "\n", 200 | " res_dict = {'max_u': list(), 'n_tre': list()}\n", 201 | "\n", 202 | " for sim_idx in range(n_sims):\n", 203 | " print(f\"\\rSim {sim_idx+1}/{n_sims}\", end=\"\")\n", 204 | " \n", 205 | " sir_obj = SimulationSIR(graph, beta=beta, delta=delta, gamma=gamma, rho=rho, verbose=True)\n", 206 | " sir_obj.launch_epidemic(\n", 207 | " init_event_list=init_event_list,\n", 208 | " max_time=max_days, \n", 209 | " policy=policy,\n", 210 | " policy_dict=policy_params\n", 211 | " )\n", 212 | " \n", 213 | " res_dict['max_u'].append(sir_obj.max_total_control_intensity)\n", 214 | " res_dict['n_tre'].append(sir_obj.is_tre.sum())\n", 215 | " \n", 216 | " return res_dict" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "Run the first simulations with SOC policy." 
224 | ] 225 | }, 226 | { 227 | "cell_type": "code", 228 | "execution_count": 9, 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "=== q_x 1 / 9\n", 236 | "--- Network 1 / 5\n", 237 | "graph: 7456 nodes, 15962 edges\n", 238 | "31.96 days elapsed | 6726 sus., 423 inf., 307 rec., 20 tre (4.73% of inf) | max_u 5.08e+00" 239 | ] 240 | }, 241 | { 242 | "ename": "KeyboardInterrupt", 243 | "evalue": "", 244 | "output_type": "error", 245 | "traceback": [ 246 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 247 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 248 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mpolicy_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'q_x'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mq_x\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0md\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mworker\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'SOC'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpolicy_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_sims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnum_sims\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m res_list.append({\n", 249 | "\u001b[0;32m\u001b[0m in \u001b[0;36mworker\u001b[0;34m(policy, policy_params, n_sims)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mmax_time\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_days\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m 
\u001b[0mpolicy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mpolicy_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpolicy_params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m )\n\u001b[1;32m 20\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 250 | "\u001b[0;32m~/Workspace/EPFL/research/ongoing/disease-control/disease-control/lib/dynamics.py\u001b[0m in \u001b[0;36mlaunch_epidemic\u001b[0;34m(self, init_event_list, max_time, policy, policy_dict, stop_criteria)\u001b[0m\n\u001b[1;32m 832\u001b[0m \u001b[0mcontrolled_nodes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_inf\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_rec\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_tre\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 833\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpolicy\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'SOC'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 834\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_LP_sol\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 835\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mu_idx\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0mcontrolled_nodes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 836\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_control\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0midx_to_node\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mu_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpolicy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpolicy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 251 | "\u001b[0;32m~/Workspace/EPFL/research/ongoing/disease-control/disease-control/lib/dynamics.py\u001b[0m in \u001b[0;36m_update_LP_sol\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 752\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 753\u001b[0m \u001b[0mA_dense\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 754\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolve_lp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA_dense\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 755\u001b[0m \u001b[0md_S\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mlen_I\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 756\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 252 | "\u001b[0;32m~/miniconda3/envs/disease-ctrl/lib/python3.6/site-packages/lpsolvers/__init__.py\u001b[0m in \u001b[0;36msolve_lp\u001b[0;34m(c, G, h, A, b, solver)\u001b[0m\n\u001b[1;32m 
80\u001b[0m \"\"\"\n\u001b[1;32m 81\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msolver\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'cvxopt'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcvxopt_solve_lp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 83\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0msolver\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m'cdd'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcdd_solve_lp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 253 | "\u001b[0;32m~/miniconda3/envs/disease-ctrl/lib/python3.6/site-packages/lpsolvers/cvxopt_.py\u001b[0m in \u001b[0;36mcvxopt_solve_lp\u001b[0;34m(c, G, h, A, b, solver)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mA\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcvxopt_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mA\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mcvxopt_matrix\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m \u001b[0msol\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msolver\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 93\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m'optimal'\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msol\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'status'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"LP optimum not found: %s\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0msol\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'status'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 254 | "\u001b[0;32m~/miniconda3/envs/disease-ctrl/lib/python3.6/site-packages/cvxopt/coneprog.py\u001b[0m in \u001b[0;36mlp\u001b[0;34m(c, G, h, A, b, kktsolver, solver, primalstart, dualstart, **kwargs)\u001b[0m\n\u001b[1;32m 2811\u001b[0m \u001b[0mopts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'glpk'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2812\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mopts\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2813\u001b[0;31m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mglpk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptions\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2814\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2815\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mglpk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mh\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 255 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 256 | ] 257 | } 258 | ], 259 | "source": [ 260 | "res_list = list()\n", 261 | "\n", 262 | "num_nets = 5\n", 263 | "num_sims = 5\n", 264 | "\n", 265 | "for i, q_x in enumerate(q_x_range):\n", 266 | " print(f'=== q_x {i+1} / {len(q_x_range)}')\n", 267 | " \n", 268 | " for net_idx in range(num_nets):\n", 269 | " print(f'--- Network {net_idx+1} / {num_nets}')\n", 270 | "\n", 271 | " policy_params = copy.deepcopy(DEFAULT_POLICY_PARAMS)\n", 272 | " policy_params['q_x'] = q_x\n", 273 | "\n", 274 | " d = worker(policy='SOC', policy_params=policy_params, n_sims=num_sims)\n", 275 | " \n", 276 | " res_list.append({\n", 277 | " 'policy': 'SOC',\n", 278 | " 'q_x': q_x,\n", 279 | " 'net_idx': net_idx,\n", 280 | " 'max_u': d['max_u'],\n", 
281 | " 'n_tre': d['n_tre']\n", 282 | " })\n", 283 | " \n", 284 | "res_df = pd.DataFrame(res_list)\n", 285 | "res_df.to_csv('baseline-calibration-soc.csv', index=False)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "metadata": {}, 292 | "outputs": [], 293 | "source": [] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": 113, 298 | "metadata": {}, 299 | "outputs": [ 300 | { 301 | "name": "stdout", 302 | "output_type": "stream", 303 | "text": [ 304 | "Epidemic stopped after 31.00 days | 657 sus., 48 inf., 67 rec., 20 tre (41.67% of inf) | I(q): 222 R(q): 48 T(q): 86 |q|: 354\n", 305 | "Epidemic stopped after 31.00 days | 708 sus., 32 inf., 32 rec., 7 tre (21.88% of inf) | I(q): 120 R(q): 33 T(q): 52 |q|: 203\n", 306 | "Epidemic stopped after 31.00 days | 762 sus., 1 inf., 9 rec., 5 tre (500.00% of inf) | I(q): 20 R(q): 1 T(q): 1 |q|: 20\n", 307 | "Epidemic stopped after 31.00 days | 735 sus., 10 inf., 27 rec., 4 tre (40.00% of inf) | I(q): 85 R(q): 11 T(q): 27 |q|: 121\n", 308 | "Epidemic stopped after 31.00 days | 667 sus., 34 inf., 71 rec., 25 tre (73.53% of inf) | I(q): 182 R(q): 35 T(q): 66 |q|: 281\n", 309 | "Epidemic stopped after 31.00 days | 736 sus., 11 inf., 25 rec., 3 tre (27.27% of inf) | I(q): 90 R(q): 12 T(q): 29 |q|: 129\n", 310 | "Epidemic stopped after 31.00 days | 519 sus., 94 inf., 159 rec., 42 tre (44.68% of inf) | I(q): 377 R(q): 95 T(q): 175 |q|: 645\n", 311 | "Epidemic stopped after 31.00 days | 699 sus., 20 inf., 53 rec., 10 tre (50.00% of inf) | I(q): 116 R(q): 21 T(q): 51 |q|: 186\n", 312 | "Epidemic stopped after 28.67 days | 765 sus., 0 inf., 7 rec., 3 tre (nan% of inf) | I(q): 17 R(q): 1 T(q): 2 |q|: 18\n", 313 | "Epidemic stopped after 31.00 days | 672 sus., 44 inf., 56 rec., 13 tre (29.55% of inf) | I(q): 201 R(q): 45 T(q): 75 |q|: 319\n" 314 | ] 315 | }, 316 | { 317 | "data": { 318 | "text/plain": [ 319 | "{'max_u': [3.4499999999999997,\n", 320 | " 1.92,\n", 321 | " 0.3,\n", 
322 | " 1.11,\n", 323 | " 3.1500000000000004,\n", 324 | " 1.08,\n", 325 | " 7.590000000000001,\n", 326 | " 2.19,\n", 327 | " 0.21,\n", 328 | " 2.9999999999999996],\n", 329 | " 'n_tre': [20, 7, 5, 4, 25, 3, 42, 10, 3, 13]}" 330 | ] 331 | }, 332 | "execution_count": 113, 333 | "metadata": {}, 334 | "output_type": "execute_result" 335 | } 336 | ], 337 | "source": [ 338 | "policy_params = copy.deepcopy(DEFAULT_POLICY_PARAMS)\n", 339 | "policy_params['TR'] = 0.03\n", 340 | "\n", 341 | "res_dict_baseline = worker(policy='TR', policy_params=policy_params, n_sims=10)\n", 342 | "res_dict_baseline" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "metadata": {}, 349 | "outputs": [], 350 | "source": [] 351 | } 352 | ], 353 | "metadata": { 354 | "kernelspec": { 355 | "display_name": "Python 3", 356 | "language": "python", 357 | "name": "python3" 358 | }, 359 | "language_info": { 360 | "codemirror_mode": { 361 | "name": "ipython", 362 | "version": 3 363 | }, 364 | "file_extension": ".py", 365 | "mimetype": "text/x-python", 366 | "name": "python", 367 | "nbconvert_exporter": "python", 368 | "pygments_lexer": "ipython3", 369 | "version": "3.6.8" 370 | } 371 | }, 372 | "nbformat": 4, 373 | "nbformat_minor": 2 374 | } 375 | -------------------------------------------------------------------------------- /lib/ogata/analysis.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | from tqdm import tqdm, tqdm_notebook 6 | import os 7 | 8 | from helpers import HelperFunc 9 | 10 | matplotlib.rcParams.update({ 11 | "figure.autolayout": False, 12 | "figure.figsize": (8, 6), 13 | "figure.dpi": 72, 14 | "axes.linewidth": 0.8, 15 | "xtick.major.width": 0.8, 16 | "xtick.minor.width": 0.8, 17 | "ytick.major.width": 0.8, 18 | "ytick.minor.width": 0.8, 19 | "text.usetex": True, 20 | "font.size": 16, 21 | "axes.titlesize": 16, 22 | 
class Evaluation:
    """
    Analyze and plot the results of dynamical-system (epidemic) simulations.

    Parameters
    ----------
    data : list
        Nested list ``data[heuristic][trial]``; each trial is a dict holding
        stochastic-process records (``X``, ``H``, ``Y``, ``W``, ``Nc``, ``u``)
        plus an ``info`` dict of simulation parameters
        (``beta``, ``delta``, ``rho``, ``gamma``, ``eta``, ``Qlam``, ``Qx``,
        ``N``, ``ttotal``).
    plot_dir : str
        Directory where plots are saved (created if missing).
    description : list of str
        One label per heuristic, used for tick labels and legends.
    """

    def __init__(self, data, plot_dir, description):
        self.data = data
        self.dirname = plot_dir
        self.descr = description

        # Per-heuristic plot styling, indexed by heuristic position.
        self.colors = 'rggbbkkym'
        self.linestyles = ['-', '-', ':', '-', ':', '-', ':', '-', '-']

        # Create directory for plots
        if not os.path.exists(self.dirname):
            os.makedirs(self.dirname)

    ''' *** Helper Functions *** '''

    def __getTextBoxString(self):
        """
        Create the parameter-description string shown in plot text boxes.
        """
        # Renamed local from `dict` to `info` to avoid shadowing the builtin.
        info = self.data[0][0]['info']
        s_beta = r'$\beta$: ' + str(info['beta']) + ', '
        s_delta = r'$\delta$: ' + str(info['delta']) + ', '
        s_rho = r'$\rho$: ' + str(info['rho']) + ', '
        s_gamma = r'$\gamma$: ' + str(info['gamma']) + ', '
        s_eta = r'$\eta$: ' + str(info['eta'])
        s_Qlam = r'Q$_{\lambda}$: ' + \
            str(np.mean(info['Qlam'])) + ', '
        s_Qx = 'Q$_{X}$: ' + str(np.mean(info['Qx']))
        s_sims = 'no. of simulations: ' + str(len(self.data[0]))

        s = s_beta + s_gamma + s_delta + s_rho + s_eta + '\n' \
            + s_Qlam + s_Qx + '\n' \
            + s_sims
        return s

    def __integrateF(self, f_of_t, eta):
        """
        Compute the integral from 0 to T of e^(- eta * t) * f(t) dt, where
        `f_of_t` is the (t, f) tuple returned by
        `HelperFunc.step_sps_values_over_time` (a piecewise-constant step
        function).

        NOTE: the original docstring said e^(eta * t); the code below has
        always integrated the *discounted* e^(- eta * t) * f(t).
        """
        t, f = f_of_t

        # Sum the exact integrals of the constant intervals; `t` holds the
        # interval endpoints in consecutive (start, end) pairs.
        val = 0.0
        indices = [(i, i + 1) for i in [2 * j for j in range(round(len(t) / 2))]]
        for i, j in indices:
            const = f[i]
            a, b = t[i], t[j]
            if eta == 0.0:
                # int_a^b const * dt = const * (b - a)
                val += const * (b - a)
            else:
                # int_a^b exp(- eta * t) * const * dt
                #   = const / eta * (exp(- a * eta) - exp(- b * eta))
                val += const / eta * (np.exp(- a * eta) - np.exp(- b * eta))
        return val

    def computeIntX(self, trial, custom_eta=None, weight_by_Qx=True):
        """
        Compute the integral from 0 to T of (Qx * X) dt for a given trial,
        discounted at rate `eta` (taken from the trial's `info` unless
        `custom_eta` is given; pass 0.0 for the undiscounted integral).
        If `weight_by_Qx` is False, X is weighted by a vector of ones instead.
        """
        hf = HelperFunc()
        if custom_eta is None:
            eta = trial['info']['eta']
        else:
            eta = custom_eta
        X_, Qx = trial['X'], trial['info']['Qx']
        t, X = hf.step_sps_values_over_time(X_, summed=False)
        if X:
            if weight_by_Qx:
                f_of_t = t, np.dot(Qx, np.array(X).T)
            else:
                f_of_t = t, np.dot(np.ones(Qx.shape), np.array(X).T)
        else:
            # No recorded values: integrand is identically zero.
            f_of_t = t, 0
        return self.__integrateF(f_of_t, eta)

    def __computeIntLambda(self, trial, custom_eta=None):
        """
        Compute integral from 0 to T of (0.5 * Qlam * u^2) dt
        for a given trial (discounted like `computeIntX`).
        """
        hf = HelperFunc()
        if custom_eta is None:
            eta = trial['info']['eta']
        else:
            eta = custom_eta
        u_, Qlam = trial['u'], trial['info']['Qlam']
        t, u = hf.step_sps_values_over_time(u_, summed=False)
        if u:
            f_of_t = t, 0.5 * np.dot(Qlam, np.square(np.array(u)).T)
        else:
            f_of_t = t, 0
        return self.__integrateF(f_of_t, eta)

    def _computeIntH(self, trial, custom_eta=None):
        """
        Compute integral from 0 to T of |H|_1 dt for a given trial
        (treatment coverage, discounted like `computeIntX`).
        """
        hf = HelperFunc()
        if custom_eta is None:
            eta = trial['info']['eta']
        else:
            eta = custom_eta
        H_ = trial['H']
        t, H = hf.step_sps_values_over_time(H_, summed=False)

        if H:
            f_of_t = t, np.dot(np.ones(trial['info']['N']), np.array(H).T)
        else:
            f_of_t = t, 0

        return self.__integrateF(f_of_t, eta)

    ''' *** ANALYSIS *** '''

    def present_discounted_loss(self, plot=False, save=False):
        """
        Plot PDV of total incurred cost
        (i.e. the infinite horizon loss function).

        Returns 0; prints a (mean, stddev) summary per heuristic.
        """
        # Compute integral for every heuristic
        print(("Computing present discounted loss integral "
               "for every heuristic..."))
        pdvs_by_heuristic = [[self.computeIntX(trial) + self.__computeIntLambda(trial)
                              for trial in tqdm(heuristic)] for heuristic in self.data]
        means, stddevs = [np.mean(pdvs) for pdvs in pdvs_by_heuristic], [np.std(pdvs) for pdvs in pdvs_by_heuristic]
        print("...done.")

        # Plotting functionality
        if plot:
            fig = plt.figure(figsize=(10, 6), facecolor='white')
            ax = fig.add_subplot(111, frameon=False)

            x = np.arange(len(means))
            width = 0.2
            ax.bar(x + width / 2, means, yerr=stddevs,
                   width=width, align='center', color='rgbkymcgbkymc')
            ax.set_xticks(x + width / 2)
            ax.set_xlabel('Policies')
            ax.set_xticklabels(self.descr)
            ax.set_ylabel('Present discounted loss')

            # text box with the simulation parameters
            box = True
            if box:
                s = self.__getTextBoxString()
                _, upper = ax.get_ylim()
                plt.text(0.0, 0.8 * upper, s, size=12,
                         va="baseline", ha="left", multialignment="left",
                         bbox=dict(fc="none"))

            plt.title(("Cumulative present discounted loss "
                       "(=objective function) for all heuristics"))
            if save:
                plt.savefig(os.path.join(self.dirname, 'PDV_plot.png'),
                            format='png', frameon=False)
                plt.close()
            else:
                plt.show()
        print("\nPresent discounted loss (Mean, StdDev) \n")
        for j in range(len(self.data)):
            print(self.descr[j] + ':\t' + str(round(means[j], 3)) + '\t' + str(round(stddevs[j], 3)) )
        return 0

    def infections_and_interventions_complete(self, size_tup=(15, 10), save=False):
        """
        Summarizes simulations in 3 plots
        - Infection coverage (Int X(t) dt) - Total discrete interventions (Sum N(T))
        - Infection coverage (Int X(t) dt) - Treatment coverage (Int H(t) dt)
        - Infection events (Sum Y(T)) - Total discrete interventions (Sum N(T))

        Returns ((intX_m, intX_s), (N_m, N_s), (intH_m, intH_s), (Y_m, Y_s)).

        NOTE(review): `size_tup` is currently unused (figure size is
        hard-coded to (12, 8)); kept for interface compatibility.
        """

        # Compute statistics for every heuristic
        hf = HelperFunc()
        intX_by_heuristic = [[self.computeIntX(trial, custom_eta=0.0, weight_by_Qx=False)
                              for trial in heuristic] for heuristic in tqdm_notebook(self.data)]

        intX_m = np.array([np.mean(h) for h in intX_by_heuristic])
        intX_s = np.array([np.std(h) for h in intX_by_heuristic])

        intH_by_heuristic = [[self._computeIntH(trial, custom_eta=0.0)
                              for trial in heuristic] for heuristic in tqdm_notebook(self.data)]

        intH_m = np.array([np.mean(h) for h in intH_by_heuristic])
        intH_s = np.array([np.std(h) for h in intH_by_heuristic])

        Y_by_heuristic = [[hf.sps_values(trial['Y'], trial['info']['ttotal'], summed=True)
                           for trial in heuristic] for heuristic in tqdm_notebook(self.data)]

        Y_m = np.array([np.mean(h) for h in Y_by_heuristic])
        Y_s = np.array([np.std(h) for h in Y_by_heuristic])

        N_by_heuristic = [[hf.sps_values(trial['Nc'], trial['info']['ttotal'], summed=True)
                           for trial in heuristic] for heuristic in tqdm_notebook(self.data)]

        N_m = np.array([np.mean(h) for h in N_by_heuristic])
        N_s = np.array([np.std(h) for h in N_by_heuristic])

        x = np.arange(len(intX_m))
        # NOTE(review): trial count used for std errors is hard-coded;
        # presumably should be len(self.data[0]) -- confirm with callers.
        n = 50  # trials per simulation

        # Plotting functionality
        plt.rc('text', usetex=True)
        plt.rc('font', family='serif')

        fig = plt.figure(figsize=(12, 8), facecolor='white')

        # 1 - Infection coverage (Int X(t) dt) - Total discrete interventions (Sum N(T))
        ax = fig.add_subplot(2, 1, 1, frameon=False)

        width = 0.2
        ax.bar(x, intX_m, yerr=intX_s / np.sqrt(n),
               width=width, align='center', color='rgbkymcgbkymc')
        ax.set_xticks(x + width / 2)
        ax.set_xlabel(r'Policies')
        ax.set_xticklabels(self.descr)
        ax.set_ylabel(r'Infection coverage $\int_{t_0}^{t_f} \mathbf{X}(t) dt$')

        ax2 = ax.twinx()
        ax2.patch.set_visible(False)
        ax2.bar(x + width, N_m, yerr=N_s / np.sqrt(n),
                width=width, align='center', color='rgbkymcgbkymc', alpha=0.5)
        ax2.set_ylabel(r'Interventions $\sum_{i=1}^{|nodes|} \mathbf{N}_i(t_f)$')

        plt.title(r"Infection coverage and discrete interventions")

        # 2 - Infection coverage (Int X(t) dt) - Treatment coverage (Int H(t) dt)
        ax = fig.add_subplot(2, 2, 3, frameon=False)

        width = 0.2
        ax.bar(x, intX_m, yerr=intX_s / np.sqrt(n),
               width=width, align='center', color='rgbkymcgbkymc')
        ax.set_xticks(x + width / 2)
        ax.set_xlabel(r'Policies')
        ax.set_xticklabels(['']*len(self.descr))
        ax.set_ylabel(r'Infection coverage $\int_{t_0}^{t_f} \mathbf{X}(t) dt$')

        ax2 = ax.twinx()
        ax2.patch.set_visible(False)
        ax2.bar(x + width, intH_m, yerr=intH_s / np.sqrt(n),
                width=width, align='center', color='rgbkymcgbkymc', alpha=0.5)
        ax2.set_ylabel(r'Treatment coverage $\int_{t_0}^{t_f} \mathbf{H}(t) dt$')

        plt.title(r"Infection coverage and treatment coverage")

        # 3 - Infection events (Sum Y(T)) - Total discrete interventions (Sum N(T))
        ax = fig.add_subplot(2, 2, 4, frameon=False)

        width = 0.2
        # BUGFIX: this panel is titled/labelled "Infection events" but
        # previously plotted intX_m/intX_s again, while Y_m/Y_s were
        # computed and never drawn. Plot the infection-event counts.
        ax.bar(x, Y_m, yerr=Y_s / np.sqrt(n),
               width=width, align='center', color='rgbkymcgbkymc')
        ax.set_xticks(x + width / 2)
        ax.set_xlabel(r'Policies')
        ax.set_xticklabels(['']*len(self.descr))
        ax.set_ylabel(r'Infections $\sum_{i=1}^{|nodes|} \mathbf{Y}_i(t_f)$')

        ax2 = ax.twinx()
        ax2.patch.set_visible(False)
        ax2.bar(x + width, N_m, yerr=N_s / np.sqrt(n),
                width=width, align='center', color='rgbkymcgbkymc', alpha=0.5)
        ax2.set_ylabel(r'Interventions $\sum_{i=1}^{|nodes|} \mathbf{N}_i(t_f)$')

        plt.title(r"Infection events and discrete interventions")
        plt.tight_layout()

        if save:
            fig_filename = os.path.join(self.dirname, "infections_and_interventions_complete" + '.pdf')
            plt.savefig(fig_filename, format='pdf', frameon=False, dpi=300)
            plt.close()
        else:
            plt.show()

        return ((intX_m, intX_s), (N_m, N_s), (intH_m, intH_s), (Y_m, Y_s))

    def infection_cost_AND_intervention_effort(self, plot=False, save=False):
        """
        Plots the TOTAL infection cost (Qx.X) and the TOTAL intervention
        effort (Qlam.u^2), both undiscounted (eta = 0).

        Returns 0; prints a (mean, stddev) summary per heuristic.
        """
        # Compute total infection cost and total time under treatment
        # for every heuristic
        print(("Computing total infection cost and total time under treatment "
               "for every heuristic..."))
        infection_cost_by_heuristic = [[self.computeIntX(trial, custom_eta=0.0) for trial in heuristic] for heuristic in self.data]
        treatment_time_by_heuristic = [[self.__computeIntLambda(trial, custom_eta=0.0) for trial in heuristic] for heuristic in self.data]
        print("...done.")

        means_infection, stddevs_infection = [np.mean(infections) for infections in infection_cost_by_heuristic], \
            [np.std(infections) for infections in infection_cost_by_heuristic]
        means_treatment, stddevs_treatment = [np.mean(treatments) for treatments in treatment_time_by_heuristic], \
            [np.std(treatments) for treatments in treatment_time_by_heuristic]

        # Plotting functionality
        if plot:
            fig = plt.figure(figsize=(10, 6), facecolor='white')
            ax = fig.add_subplot(111, frameon=False)
            x = np.arange(len(means_infection))
            width = 0.2
            ax.bar(x, means_infection, yerr=stddevs_infection,
                   width=width, align='center', color='rgbkymcgbkymc')
            ax.set_xticks(x + width / 2)
            ax.set_xlabel('Policies')
            ax.set_xticklabels(self.descr)
            ax.set_ylabel('Infection cost incurred [Left]')

            ax2 = ax.twinx()
            ax2.patch.set_visible(False)
            ax2.bar(x + width, means_treatment, yerr=stddevs_treatment,
                    width=width, align='center', color='rgbkymcgbkymc', alpha=0.5)
            ax2.set_ylabel('Treatment effort [Right]')

            plt.title(("Total infection cost Int(Qx.X) [Left] & "
                       "Total intervention effort Int(Qlam.Lambda^2) "
                       "[Right] for all heuristics"))

            if save:
                plt.savefig(os.path.join(
                    self.dirname, 'infection_cost_AND_intervention_effort.png'),
                    format='png', frameon=False)
                plt.close()
            else:
                plt.show()

        print(("\nTotal infection cost and total intervention effort "
               "(Mean, StdDev) \n"))
        for j in range(len(self.data)):
            print(self.descr[j] + ':\t' + str(round(means_infection[j], 3)) +
                  '\t' + str(round(stddevs_infection[j], 3)) +
                  '\t ---- \t' + str(round(means_treatment[j], 3)) +
                  '\t' + str(round(stddevs_treatment[j], 3)))
        return 0

    def summarize_interventions_and_intensities(self):
        """
        Return total number of interventions & peak and average treatment
        intensities for every heuristic.

        Currently only prints per-trial and per-heuristic peak intensities
        and returns 0 (treatment summary is commented out below).
        """

        hf = HelperFunc()

        # Intensities: for each trial, record the max intensity over all
        # nodes at every arrival time of the control process u.
        max_intensities = np.zeros((len(self.data), len(self.data[0])), dtype=object)
        for i, heuristic in enumerate(tqdm(self.data)):
            for j, trial in enumerate(heuristic):
                all_arrivals = hf.all_arrivals(trial['u'])
                max_intensities[i, j] = np.zeros(len(all_arrivals))
                for k, t in enumerate(all_arrivals):
                    max_intensities[i, j][k] = np.max(
                        hf.sps_values(trial['u'], t, summed=False))

        max_per_trial = [[np.max(trial) for trial in heuristic]
                         for heuristic in tqdm(max_intensities)]
        max_per_heuristic = [np.max(heuristic) for heuristic in max_per_trial]

        print(max_per_trial)
        print(max_per_heuristic)

        # Treatments

        # treatments_by_heuristic = [[hf.sps_values(trial['Nc'], trial['info']['ttotal'], summed=True)
        #                             for trial in heuristic] for heuristic in self.data]

        # means_treatment, stddevs_treatment = \
        #     [np.mean(treatments) for treatments in treatments_by_heuristic], \
        #     [np.std(treatments) for treatments in treatments_by_heuristic]

        return 0  # TODO Change back and delete this

    def simulation_plot(self, process, figsize=(8, 6), granularity=0.1,
                        filename='simulation_summary', draw_box=False,
                        save=False):
        """
        Plot a summary of the simulation.

        Parameters
        ----------
        process : str
            Process to plot (`X`, `H`, `Y`, `W`, `Nc`, `u`)
        figsize : tuple, optional
            Figure size
        granularity : float, optional
            Time (x-axis) granularity of the plot
        filename : str, optional
            Filename to save the plot
        save : bool, optional
            If True, the plot is saved to `filename`
        draw_box : bool, optional
            If True, draw a text box with simulation parameters
        """

        print(f"Building simulation figure for process {process:s}...")

        fig, ax = plt.subplots(1, 1, figsize=figsize)

        hf = HelperFunc()

        for i, heuristic in enumerate(self.data):
            # Number of trials
            n_trials = len(heuristic)
            # Simulation end time
            ttotal = heuristic[0]['info']['ttotal']
            # Linearly spaced array of time
            tspace = np.arange(0.0, ttotal, granularity)
            # Extract the values of the stochastic processes at all times
            # for each trial
            values = np.zeros((n_trials, len(tspace)))
            for j, trial in enumerate(tqdm_notebook(heuristic)):
                for k, t in enumerate(tspace):
                    values[j, k] = hf.sps_values(trial[process], t, summed=True)
            # Compute mean and std over trials
            mean_X_t = np.mean(values, axis=0)
            stddev_X_t = np.std(values, axis=0)
            # Plot the mean +/- std
            ax.plot(tspace, mean_X_t, color=self.colors[i], linestyle=self.linestyles[i])
            ax.fill_between(tspace, mean_X_t - stddev_X_t, mean_X_t + stddev_X_t,
                            alpha=0.3, edgecolor=self.colors[i], facecolor=self.colors[i],
                            linewidth=0)
        ax.set_xlim([0, ttotal])
        ax.set_xlabel("Elapsed time")
        # ax.set_ylim([0, heuristic[0]['info']['N']])
        ax.set_ylim(bottom=0)
        if process == 'X':
            ax.set_ylabel("Number of infected nodes")
        elif process == 'H':
            ax.set_ylabel("Number of treated nodes")
        else:
            ax.set_ylabel("Number of nodes")

        # Text box with simulation parameters
        if draw_box:
            s = self.__getTextBoxString()
            _, upper = ax.get_ylim()
            plt.text(0.0, upper, s, size=12,
                     va="baseline", ha="left", multialignment="left",
                     bbox=dict(fc="none"))
        # Legend: one entry per heuristic description
        legend = []
        for policy in self.descr:
            legend += [policy]
        ax.legend(legend)
        plt.draw()

        if save:
            fig_filename = os.path.join(self.dirname, filename + '.pdf')
            plt.savefig(fig_filename, format='pdf', frameon=False, dpi=300)
            plt.close()
        else:
            plt.show()
527 | () 528 | 529 | } 530 | } 531 | """ 532 | self.multi_summary = multi_summary 533 | self.policy_list = policy_list 534 | self.n_trials = n_trials 535 | 536 | self.colors = 'rggbbkkym' 537 | self.linestyles = ['-', '-', ':', '-', ':', '-', ':', '-', '-'] 538 | 539 | # create directory for plots 540 | self.save_dir = save_dir 541 | if not os.path.exists(self.save_dir): 542 | os.makedirs(self.save_dir) 543 | 544 | def compare_infections(self, size_tup=(8, 6), save=True): 545 | """ 546 | Compare infections along Qx axis (assuming Qlambda = const = 1.0). 547 | """ 548 | d = self.multi_summary.get('infections_and_interventions', None) 549 | Qs = self.multi_summary.get('Qs', None) 550 | 551 | if d is None or Qs is None: 552 | raise ValueError('Missing data.') 553 | 554 | keys = np.array(list(Qs.keys())) 555 | n_exps = len(keys) 556 | 557 | # assumes data had same methods tested 558 | infections_axis = {name: np.zeros(n_exps) for name in self.policy_list} 559 | infections_axis_std = {name: np.zeros(n_exps) for name in self.policy_list} 560 | interventions_axis = {name: np.zeros(n_exps) for name in self.policy_list} 561 | interventions_axis_std = {name: np.zeros(n_exps) for name in self.policy_list} 562 | 563 | # Build X-axis 564 | Qx_axis = np.array([Qs[key] for key in keys]) 565 | # Sort by value 566 | sorted_args = np.argsort(Qx_axis) 567 | Qx_axis = Qx_axis[sorted_args] 568 | keys = keys[sorted_args] 569 | 570 | # Build Y-axis 571 | for i, name in enumerate(self.policy_list): 572 | for j, key in enumerate(keys): 573 | # d[key] = ((intX_m, intX_s), (N_m, N_s), (intH_m, intH_s), (Y_m, Y_s)) 574 | infections_axis[name][j] = d[key][0][0][i] 575 | infections_axis_std[name][j] = d[key][0][1][i] / np.sqrt(self.n_trials) # transfrom stddev into std error 576 | interventions_axis[name][j] = d[key][1][0][i] 577 | interventions_axis_std[name][j] = d[key][1][1][i] / np.sqrt(self.n_trials) # transfrom stddev into std error 578 | 579 | 580 | 581 | # Set up figure. 
582 | fig, ax = plt.subplots(1, 1, figsize=(12, 8), facecolor='white') 583 | plt.cla() 584 | 585 | hf = HelperFunc() 586 | 587 | legend = [] 588 | max_infected = 0 589 | for ind, name in enumerate(self.policy_list): 590 | legend.append(name) 591 | 592 | # linear axis 593 | ax.plot(Qx_axis, infections_axis[name], color=self.colors[ind], 594 | linestyle=self.linestyles[ind]) 595 | ax.fill_between( 596 | Qx_axis, 597 | infections_axis[name] - interventions_axis_std[name], 598 | infections_axis[name] + interventions_axis_std[name], 599 | alpha=0.3, 600 | edgecolor=self.colors[ind], 601 | facecolor=self.colors[ind], 602 | linewidth=0) 603 | 604 | # ax.errorbar(Qx_axis, infections_axis[name], yerr=interventions_axis_std[name]) 605 | 606 | if max(infections_axis[name]) > max_infected: 607 | max_infected = max(infections_axis[name]) 608 | 609 | ax.set_xlim([0.7, max(Qx_axis)]) 610 | ax.set_xlabel(r'$Q_x$') 611 | ax.set_ylim([0, 1.3 * max_infected]) 612 | ax.set_ylabel(r'Infection coverage $\int_{t_0}^{t_f} \mathbf{X}(t) dt$') 613 | ax.legend(legend) 614 | 615 | # ax.set_xscale("log", nonposx='clip') 616 | 617 | if save: 618 | dpi = 300 619 | # plt.tight_layout() 620 | fig = plt.gcf() # get current figure 621 | fig.set_size_inches(size_tup) # width, height 622 | plt.savefig(os.path.join(self.save_dir, 'infections_fair_comparison.png'), frameon=False, format='png', dpi=dpi) 623 | plt.close() 624 | else: 625 | plt.show() 626 | -------------------------------------------------------------------------------- /notebooks/1-preprocessing/Debug - dynamics_ind.SimulationSIR.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | 
"import sys\n", 20 | "if '..' not in sys.path:\n", 21 | " sys.path.append('..')" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import random as rd\n", 31 | "import numpy as np\n", 32 | "import networkx as nx\n", 33 | "\n", 34 | "from matplotlib import pyplot as plt\n", 35 | "%matplotlib inline\n", 36 | "\n", 37 | "from dynamics_ind import SimulationSIR, PriorityQueue, sample_seeds" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 17, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "n_nodes = 1000\n", 47 | "graph = nx.complete_graph(n=n_nodes)\n", 48 | "\n", 49 | "# graph = nx.from_edgelist([(0, 1), (1,2), (2,3), (3,4), (4,5)], create_using=nx.Graph)\n", 50 | "# n_nodes = graph.number_of_nodes()\n", 51 | "\n", 52 | "A = nx.adjacency_matrix(graph).toarray()\n", 53 | "\n", 54 | "\n", 55 | "beta = 1 / 8.0\n", 56 | "delta = n_nodes / 12.3\n", 57 | "\n", 58 | "gamma = 0.0\n", 59 | "rho = 0.0\n", 60 | "eta = 1.0\n", 61 | "\n", 62 | "q_x = 300.0\n", 63 | "q_lam = 1.0\n", 64 | "\n", 65 | "max_time = 10.0\n", 66 | "\n", 67 | "policy_dict = {\n", 68 | " 'TR': 0.06,\n", 69 | " 'MN': 0.012,\n", 70 | " 'LN': 0.012,\n", 71 | " 'LRSR': 22.807,\n", 72 | " 'MCM': 22.807,\n", 73 | " 'FL_info': {'N': None, 'max_u': None},\n", 74 | "}\n", 75 | "\n", 76 | "policy_name = 'NO'" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 18, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | "[[(0, 'inf', 0), 0.0], [(0, 'rec', None), 0.020955449007606965]]\n" 89 | ] 90 | } 91 | ], 92 | "source": [ 93 | "init_event_list = [\n", 94 | " [(0, 'inf', 0), 0.0],\n", 95 | " [(0, 'rec', None), rd.expovariate(delta)],\n", 96 | "]\n", 97 | "\n", 98 | "X_init = np.zeros(n_nodes)\n", 99 | "X_init[0] = 1\n", 100 | "\n", 101 | "print(init_event_list)" 102 | ] 103 | }, 104 | { 105 | "cell_type": 
"markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "---" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 19, 114 | "metadata": { 115 | "scrolled": true 116 | }, 117 | "outputs": [ 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "1\n", 123 | "Epidemic stopped after 0.03 days | 998 sus., 0 inf., 2 rec., 0 tre (nan% of inf) | I(q): 1995 R(q): 0 T(q): 0 |q|: 1995\n", 124 | "2\n", 125 | "Epidemic stopped after 0.25 days | 362 sus., 0 inf., 638 rec., 0 tre (nan% of inf) | I(q): 424207 R(q): 0 T(q): 0 |q|: 424207\n", 126 | "3\n", 127 | "Epidemic stopped after 0.03 days | 993 sus., 0 inf., 7 rec., 0 tre (nan% of inf) | I(q): 6961 R(q): 0 T(q): 0 |q|: 6961\n", 128 | "4\n", 129 | "Epidemic stopped after 0.02 days | 998 sus., 0 inf., 2 rec., 0 tre (nan% of inf) | I(q): 1995 R(q): 0 T(q): 0 |q|: 1995\n", 130 | "5\n", 131 | "Epidemic stopped after 0.22 days | 511 sus., 0 inf., 489 rec., 0 tre (nan% of inf) | I(q): 363685 R(q): 0 T(q): 0 |q|: 363685\n", 132 | "6\n", 133 | "Epidemic stopped after 0.25 days | 408 sus., 0 inf., 592 rec., 0 tre (nan% of inf) | I(q): 409317 R(q): 0 T(q): 0 |q|: 409317\n", 134 | "7\n", 135 | "Epidemic stopped after 0.31 days | 510 sus., 0 inf., 490 rec., 0 tre (nan% of inf) | I(q): 361773 R(q): 0 T(q): 0 |q|: 361773\n", 136 | "8\n", 137 | "Epidemic stopped after 0.12 days | 984 sus., 0 inf., 16 rec., 0 tre (nan% of inf) | I(q): 15725 R(q): 0 T(q): 0 |q|: 15725\n", 138 | "9\n", 139 | "Epidemic stopped after 0.08 days | 993 sus., 0 inf., 7 rec., 0 tre (nan% of inf) | I(q): 6934 R(q): 0 T(q): 0 |q|: 6934\n", 140 | "10\n", 141 | "Epidemic stopped after 0.03 days | 995 sus., 0 inf., 5 rec., 0 tre (nan% of inf) | I(q): 4981 R(q): 0 T(q): 0 |q|: 4981\n", 142 | "11\n", 143 | "Epidemic stopped after 0.23 days | 441 sus., 0 inf., 559 rec., 0 tre (nan% of inf) | I(q): 395840 R(q): 0 T(q): 0 |q|: 395840\n", 144 | "12\n", 145 | "Epidemic stopped after 0.04 days | 988 sus., 0 inf., 12 rec., 0 
tre (nan% of inf) | I(q): 11888 R(q): 0 T(q): 0 |q|: 11888\n", 146 | "13\n", 147 | "Epidemic stopped after 0.28 days | 320 sus., 0 inf., 680 rec., 0 tre (nan% of inf) | I(q): 438214 R(q): 0 T(q): 0 |q|: 438214\n", 148 | "14\n", 149 | "Epidemic stopped after 0.02 days | 998 sus., 0 inf., 2 rec., 0 tre (nan% of inf) | I(q): 1995 R(q): 0 T(q): 0 |q|: 1995\n", 150 | "15\n", 151 | "Epidemic stopped after 0.22 days | 370 sus., 0 inf., 630 rec., 0 tre (nan% of inf) | I(q): 423958 R(q): 0 T(q): 0 |q|: 423958\n", 152 | "16\n", 153 | "Epidemic stopped after 0.02 days | 998 sus., 0 inf., 2 rec., 0 tre (nan% of inf) | I(q): 1996 R(q): 0 T(q): 0 |q|: 1996\n", 154 | "17\n", 155 | "Epidemic stopped after 0.38 days | 379 sus., 0 inf., 621 rec., 0 tre (nan% of inf) | I(q): 417000 R(q): 0 T(q): 0 |q|: 417000\n", 156 | "18\n", 157 | "Epidemic stopped after 0.02 days | 998 sus., 0 inf., 2 rec., 0 tre (nan% of inf) | I(q): 1993 R(q): 0 T(q): 0 |q|: 1993\n", 158 | "19\n", 159 | "Epidemic stopped after 0.02 days | 999 sus., 0 inf., 1 rec., 0 tre (nan% of inf) | I(q): 999 R(q): 0 T(q): 0 |q|: 999\n", 160 | "20\n", 161 | "Epidemic stopped after 0.03 days | 993 sus., 0 inf., 7 rec., 0 tre (nan% of inf) | I(q): 6952 R(q): 0 T(q): 0 |q|: 6952\n", 162 | "21\n", 163 | "Epidemic stopped after 0.04 days | 996 sus., 0 inf., 4 rec., 0 tre (nan% of inf) | I(q): 3983 R(q): 0 T(q): 0 |q|: 3983\n", 164 | "22\n", 165 | "Epidemic stopped after 0.06 days | 991 sus., 0 inf., 9 rec., 0 tre (nan% of inf) | I(q): 8927 R(q): 0 T(q): 0 |q|: 8927\n", 166 | "23\n", 167 | "Epidemic stopped after 0.26 days | 457 sus., 0 inf., 543 rec., 0 tre (nan% of inf) | I(q): 386871 R(q): 0 T(q): 0 |q|: 386871\n", 168 | "24\n", 169 | "Epidemic stopped after 0.25 days | 429 sus., 0 inf., 571 rec., 0 tre (nan% of inf) | I(q): 401595 R(q): 0 T(q): 0 |q|: 401595\n", 170 | "25\n", 171 | "Epidemic stopped after 0.24 days | 370 sus., 0 inf., 630 rec., 0 tre (nan% of inf) | I(q): 424967 R(q): 0 T(q): 0 |q|: 424967\n", 172 | "26\n", 173 
| "Epidemic stopped after 0.27 days | 353 sus., 0 inf., 647 rec., 0 tre (nan% of inf) | I(q): 426785 R(q): 0 T(q): 0 |q|: 426785\n", 174 | "27\n", 175 | "Epidemic stopped after 0.28 days | 499 sus., 0 inf., 501 rec., 0 tre (nan% of inf) | I(q): 366174 R(q): 0 T(q): 0 |q|: 366174\n", 176 | "28\n", 177 | "Epidemic stopped after 0.02 days | 998 sus., 0 inf., 2 rec., 0 tre (nan% of inf) | I(q): 1996 R(q): 0 T(q): 0 |q|: 1996\n", 178 | "29\n", 179 | "Epidemic stopped after 0.05 days | 984 sus., 0 inf., 16 rec., 0 tre (nan% of inf) | I(q): 15819 R(q): 0 T(q): 0 |q|: 15819\n", 180 | "30\n", 181 | "Epidemic stopped after 0.27 days | 432 sus., 0 inf., 568 rec., 0 tre (nan% of inf) | I(q): 397265 R(q): 0 T(q): 0 |q|: 397265\n" 182 | ] 183 | } 184 | ], 185 | "source": [ 186 | "param_dict = {\n", 187 | " 'beta': beta, # Infection rate\n", 188 | " 'delta': delta, # Spontaneous recovery rate\n", 189 | " \n", 190 | " 'gamma': gamma, # Reduction in infection rate from treatement\n", 191 | " 'rho': rho, # Recovery rate from treatement\n", 192 | " 'eta': eta, # Not used\n", 193 | " 'q_x': q_x, # Not used\n", 194 | " 'q_lam': q_lam # Not used\n", 195 | "}\n", 196 | "\n", 197 | "res_1 = {'sus': list(), 'inf': list(), 'rec': list()}\n", 198 | "\n", 199 | "for i in range(30):\n", 200 | " print(i+1)\n", 201 | " sir = SimulationSIR(graph, **param_dict, verbose=True)\n", 202 | " sir.launch_epidemic(init_event_list=init_event_list, max_time=max_time, \n", 203 | " policy=policy_name, policy_dict=policy_dict)\n", 204 | " \n", 205 | " res_1['sus'].append(np.sum(sir.is_sus))\n", 206 | " res_1['inf'].append(np.sum(sir.is_inf * (1 - sir.is_rec)))\n", 207 | " res_1['rec'].append(np.sum(sir.is_rec))" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 20, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "sus 724.8333333333334\n", 220 | "inf 0.0\n", 221 | "rec 275.1666666666667\n" 222 | ] 223 | } 
224 | ], 225 | "source": [ 226 | "for k in res_1.keys():\n", 227 | " print(k, np.mean(res_1[k]))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 21, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "data": { 237 | "text/plain": [ 238 | "(array([ 6., 4., 4., 0., 0., 0., 0., 0., 0., 16.]),\n", 239 | " array([320. , 387.9, 455.8, 523.7, 591.6, 659.5, 727.4, 795.3, 863.2, 931.1, 999. ]),\n", 240 | " )" 241 | ] 242 | }, 243 | "execution_count": 21, 244 | "metadata": {}, 245 | "output_type": "execute_result" 246 | }, 247 | { 248 | "data": { 249 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAD41JREFUeJzt3X2MZXddx/H3xw4FtpA+7VAL23GKQg0ShDpgK1KhhbpQQo3BpBvRAjWToCIgSbNIkPBfi0TBaIQNXUoUF7CWh5SHtvJg1eDibB/obrelBdaytWWnVsCHxFL4+sc9bafj7M7ch7lz95f3K5nMOb9zZs4nZ+9+5txzz7k3VYUk6ej3YxsdQJI0Gha6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqRFT49zY5s2ba3Z2dpyblKSj3p49e+6vqunV1htroc/OzrKwsDDOTUrSUS/Jv65lPU+5SFIjLHRJaoSFLkmNsNAlqREWuiQ1YtVCT7IzyaEke5eNvzHJ7Un2JXn3+kWUJK3FWo7QrwS2Lh1I8hLgQuBnq+pngPeMPpokqR+rFnpV3QA8sGz4DcBlVfW/3TqH1iGbJKkPg55DfybwoiS7k/x9kuePMpQkqX+D3ik6BZwEnAU8H/h4kqfXCp84nWQemAeYmZkZNKckDW12+2c2bNsHLrtg3bcx6BH6QeDq6vkq8CNg80orVtWOqpqrqrnp6VXfikCSNKBBC/2TwEsAkjwTOBa4f1ShJEn9W/WUS5JdwIuBzUkOAu8EdgI7u0sZHwQuXul0iyRpfFYt9KradphFrxlxFknSELxTVJIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhqxaqEn2ZnkUPdxc8uXvTVJJVnxA6IlSeOzliP0K4GtyweTnAacD9w94kySpAGsWuhVdQPwwAqL/gS4FPDDoSVpAgx0Dj3JhcA9VXXLiPNIkgY01e8PJNkE/AG90y1rWX8emAeYmZnpd3OSpDUa5Aj9J4HTgVuSHAC2ADcm+fGVVq6qHVU1V1Vz09PTgyeVJB1R30foVXUr8JSH57tSn6uq+0eYS5LUp7VctrgL+ApwRpKDSS5Z/1iSpH6teoReVdtWWT47sjSSpIF5p6gkNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1Yi0fQbczyaEke5eM/VGS25N8LcknkpywvjElSatZyxH6lcDWZWPXA8+uqucAXwfeNuJckqQ+rVroVXU
D8MCyseuq6qFu9p+BLeuQTZLUh1GcQ3898LnDLUwyn2QhycLi4uIINidJWslQhZ7k7cBDwEcOt05V7aiquaqam56eHmZzkqQjmBr0B5O8FnglcF5V1cgSSZIGMlChJ9kKXAr8UlX9z2gjSZIGsZbLFncBXwHOSHIwySXAnwFPBq5PcnOS969zTknSKlY9Qq+qbSsMX7EOWSRJQ/BOUUlqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWrEWj6CbmeSQ0n2Lhk7Kcn1Se7svp+4vjElSatZyxH6lcDWZWPbgS9U1TOAL3TzkqQNtGqhV9UNwAPLhi8EPtxNfxj4lRHnkiT1adBz6KdU1b3d9H3AKYdbMcl8koUkC4uLiwNuTpK0mqFfFK2qAuoIy3dU1VxVzU1PTw+7OUnSYQxa6N9JcipA9/3Q6CJJkgYxaKF/Gri4m74Y+NRo4kiSBrWWyxZ3AV8BzkhyMMklwGXAy5LcCby0m5ckbaCp1Vaoqm2HWXTeiLNIkobgnaKS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUiKEKPclbkuxLsjfJriRPGFUwSVJ/Bi70JE8Dfg+Yq6pnA8cAF40qmCSpP8OecpkCnphkCtgE/NvwkSRJgxi40KvqHuA9wN3AvcD3quq65eslmU+ykGRhcXFx8KSSpCMa5pTLicCFwOnAU4Hjkrxm+XpVtaOq5qpqbnp6evCkkqQjGuaUy0uBb1XVYlX9ALga+IXRxJIk9WuYQr8bOCvJpiQBzgP2jyaWJKlfw5xD3w1cBdwI3Nr9rh0jyiVJ6tPUMD9cVe8E3jmiLJKkIXinqCQ1wkKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpEUMVepITklyV5PYk+5OcPapgkqT+DPWJRcD7gM9X1auTHAtsGkEmSdIABi70JMcD5wCvBaiqB4EHRxNLktSvYU65nA4sAh9KclOSDyY5bkS5JEl9GuaUyxRwJvDGqtqd5H3AduAdS1dKMg/MA8zMzAy8sdntnxk86ZAOXHbBhm1bktZqmCP0g8DBqtrdzV9Fr+Afo6p2VNVcVc1NT08PsTlJ0pEMXOhVdR/w7SRndEPnAbeNJJUkqW/DXuXyRuAj3RUu3wReN3wkSdIghir0qroZmBtRFknSELxTVJIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhoxdKEnOSbJTUmuGUUgSdJgRnGE/iZg/wh+jyRpCEMVepItwAXAB0cTR5I0qGGP0N8LXAr8aARZJElDmBr0B5O8EjhUVXuSvPgI680D8wAzMzODbm5DzW7/zEZHGKsDl12w0REkDWCYI/QXAq9KcgD4KHBukr9avlJV7aiquaqam56eHmJzkqQjGbjQq+ptVbWlqmaBi4AvVtVrRpZMktQXr0OXpEYMfA59qar6MvDlUfwuSdJgPEKXpEZY6JLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRgxc6ElOS/KlJLcl2ZfkTaMMJknqzzAfQfcQ8NaqujHJk4E9Sa6vqttGlE2S1IeBj9Cr6t6qurGb/k9gP/C0UQWTJPVnJOfQk8wCzwN2j+L3SZL6N3ShJ3kS8LfAm6vq+yssn0+ykGRhcXFx2M1Jkg5jqEJP8jh6Zf6Rqrp6pXWqakdVzVXV3PT09DCbkyQdwTBXuQS
4AthfVX88ukiSpEEMc4T+QuA3gHOT3Nx9vWJEuSRJfRr4ssWq+kcgI8wiSRqCd4pKUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEZY6JLUCAtdkhphoUtSI4b9kOitSe5IcleS7aMKJUnq3zAfEn0M8OfAy4FnAduSPGtUwSRJ/RnmCP0FwF1V9c2qehD4KHDhaGJJkvo1TKE/Dfj2kvmD3ZgkaQNMrfcGkswD893sfyW5Y4XVNgP3r3eWEWs2cy4fQ5K1a3Y/T5CjLS8chZlz+VCZf2ItKw1T6PcApy2Z39KNPUZV7QB2HOkXJVmoqrkhsoydmcfDzOvvaMsLZj6cYU65/AvwjCSnJzkWuAj49GhiSZL6NfARelU9lOR3gWuBY4CdVbVvZMkkSX0Z6hx6VX0W+OwIchzxlMyEMvN4mHn9HW15wcwrSlWt9zYkSWPgrf+S1IixFHqSJyT5apJbkuxL8q5u/PQku7u3DvhY9+IqSR7fzd/VLZ8dR84Vch+T5KYk1xwleQ8kuTXJzUkWurGTklyf5M7u+4ndeJL8aZf5a0nO3KDMJyS5KsntSfYnOXuSMyc5o9u/D399P8mbJzlzl+Mt3f+9vUl2df8nJ/bxnORNXdZ9Sd7cjU3cPk6yM8mhJHuXjPWdM8nF3fp3Jrl44EBVte5fQIAnddOPA3YDZwEfBy7qxt8PvKGb/m3g/d30RcDHxpFzhdy/D/w1cE03P+l5DwCbl429G9jeTW8HLu+mXwF8rvu3OQvYvUGZPwz8Vjd9LHDCpGdekv0Y4D561whPbGZ6N/x9C3hiN/9x4LWT+ngGng3sBTbRe53v74CfmsR9DJwDnAnsXTLWV07gJOCb3fcTu+kTB8qzAQ+uTcCNwM/Tu8h+qhs/G7i2m74WOLubnurWy5hzbgG+AJwLXNP9I0xs3m7bB/j/hX4HcGo3fSpwRzf9AWDbSuuNMe/xXdFk2fjEZl6W83zgnyY9M4/e1X1S9/i8BvjlSX08A78GXLFk/h3ApZO6j4FZHlvofeUEtgEfWDL+mPX6+RrbOfTu9MXNwCHgeuAbwHer6qFulaVvHfDI2wp0y78HnDyurJ330nsQ/aibP5nJzgtQwHVJ9qR3hy7AKVV1bzd9H3BKNz0Jb91wOrAIfKg7tfXBJMcx2ZmXugjY1U1PbOaqugd4D3A3cC+9x+ceJvfxvBd4UZKTk2yid2R7GhO8j5fpN+fI8o+t0Kvqh1X1XHpHvi8Afnpc2+5XklcCh6pqz0Zn6dMvVtWZ9N4B83eSnLN0YfX+/E/SZU1T9J6u/kVVPQ/4b3pPUR8xgZkB6M43vwr4m+XLJi1zdw73Qnp/QJ8KHAds3dBQR1BV+4HLgeuAzwM3Az9cts5E7ePDGXfOsV/lUlXfBb5E7yneCUkevhZ+6VsHPPK2At3y44F/H2PMFwKvSnKA3rtIngu8b4LzAo8ciVFVh4BP0PvD+Z0kp3bZTqX3DOkxmTsrvnXDOjsIHKyq3d38VfQKfpIzP+zlwI1V9Z1ufpIzvxT4VlUtVtUPgKvpPcYn9vFcVVdU1c9V1TnAfwBfZ7L38VL95hxZ/nFd5TKd5IRu+onAy4D99Ir91d1qFwOf6qY/3c3TLf9i95duLKrqbVW1papm6T2t/mJV/fqk5gVIclySJz88Te/87t5l2ZZn/s3ulfezgO8teZo4FlV1H/DtJGd0Q+cBt01y5iW28ejpFpjszHcDZyXZlCQ8up8n+fH8lO77DPCr9C5OmOR9vFS/Oa8Fzk9yYvds6vxurH9jetHgOcBNwNfolcwfduNPB74K3EXvqevju/EndPN3dcufPo6ch8n+Yh69ymVi83bZbum+9gFv78ZPpvfi7p30rhY4qRsPvQ8o+QZwKzC3Qfv3ucBC99j4JL1X+Sc983H0jliPXzI26ZnfBdze/f/7S+DxE/54/gd6f3RuAc6
b1H1M74/6vcAP6D3jvGSQnMDru/19F/C6QfN4p6gkNcI7RSWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmN+D8r26TqJBXPvwAAAABJRU5ErkJggg==\n", 250 | "text/plain": [ 251 | "
" 252 | ] 253 | }, 254 | "metadata": { 255 | "needs_background": "light" 256 | }, 257 | "output_type": "display_data" 258 | } 259 | ], 260 | "source": [ 261 | "plt.hist(res_1['sus'])" 262 | ] 263 | }, 264 | { 265 | "cell_type": "markdown", 266 | "metadata": {}, 267 | "source": [ 268 | "---" 269 | ] 270 | }, 271 | { 272 | "cell_type": "code", 273 | "execution_count": 22, 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [ 277 | "from helpers import HelperFunc\n", 278 | "from dynamics_deprecated import SIRDynamicalSystem" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 23, 284 | "metadata": { 285 | "scrolled": true 286 | }, 287 | "outputs": [ 288 | { 289 | "name": "stdout", 290 | "output_type": "stream", 291 | "text": [ 292 | "1\n", 293 | "time 0.00/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 99.46 iter/s \n", 294 | "2\n", 295 | "time 0.19/10.00 | S: 455, I: 0, R: 545, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 100.20 iter/s \n", 296 | "3\n", 297 | "time 0.38/10.00 | S: 466, I: 0, R: 534, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 94.57 iter/s \n", 298 | "4\n", 299 | "time 0.01/10.00 | S: 998, I: 0, R: 2, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 108.11 iter/s \n", 300 | "5\n", 301 | "time 0.01/10.00 | S: 998, I: 0, R: 2, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 129.85 iter/s \n", 302 | "6\n", 303 | "time 0.34/10.00 | S: 456, I: 0, R: 544, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 48.27 iter/s \n", 304 | "7\n", 305 | "time 0.29/10.00 | S: 307, I: 0, R: 693, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 78.68 iter/s \n", 306 | "8\n", 307 | "time 0.00/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 93.77 iter/s \n", 308 | "9\n", 309 | "time 0.01/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 109.71 iter/s \n", 310 | "10\n", 311 | "time 0.04/10.00 | S: 992, I: 0, R: 8, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 106.52 iter/s \n", 312 | "11\n", 313 | "time 0.00/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 
0.00 | 82.31 iter/s \n", 314 | "12\n", 315 | "time 0.00/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 108.44 iter/s \n", 316 | "13\n", 317 | "time 0.00/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 89.51 iter/s \n", 318 | "14\n", 319 | "time 0.03/10.00 | S: 984, I: 0, R: 16, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 99.97 iter/s \n", 320 | "15\n", 321 | "time 0.21/10.00 | S: 334, I: 0, R: 666, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 92.17 iter/s \n", 322 | "16\n", 323 | "time 0.30/10.00 | S: 447, I: 0, R: 553, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 88.25 iter/s \n", 324 | "17\n", 325 | "time 0.01/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 102.09 iter/s \n", 326 | "18\n", 327 | "time 0.37/10.00 | S: 356, I: 0, R: 644, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 86.31 iter/s \n", 328 | "19\n", 329 | "time 0.28/10.00 | S: 402, I: 0, R: 598, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 66.87 iter/s \n", 330 | "20\n", 331 | "time 0.24/10.00 | S: 325, I: 0, R: 675, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 93.26 iter/s \n", 332 | "21\n", 333 | "time 0.02/10.00 | S: 998, I: 0, R: 2, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 129.42 iter/s \n", 334 | "22\n", 335 | "time 0.32/10.00 | S: 422, I: 0, R: 578, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 83.24 iter/s \n", 336 | "23\n", 337 | "time 0.01/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 80.90 iter/s \n", 338 | "24\n", 339 | "time 0.01/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 106.21 iter/s \n", 340 | "25\n", 341 | "time 0.42/10.00 | S: 420, I: 0, R: 580, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 94.06 iter/s \n", 342 | "26\n", 343 | "time 0.30/10.00 | S: 431, I: 0, R: 569, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 79.85 iter/s \n", 344 | "27\n", 345 | "time 0.00/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 127.75 iter/s \n", 346 | "28\n", 347 | "time 0.06/10.00 | S: 988, I: 0, R: 12, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 105.63 iter/s \n", 348 | 
"29\n", 349 | "time 0.25/10.00 | S: 327, I: 0, R: 673, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 22.62 iter/s \n", 350 | "30\n", 351 | "time 0.01/10.00 | S: 999, I: 0, R: 1, H: 0, lY: 0.00, lW: 0.00, lN: 0.00 | 55.94 iter/s \n" 352 | ] 353 | } 354 | ], 355 | "source": [ 356 | "sim_dict = {\n", 357 | " 'total_time': max_time,\n", 358 | " 'trials_per_setting': 1\n", 359 | "}\n", 360 | "param_dict = {\n", 361 | " 'beta': beta, # Infection rate\n", 362 | " 'delta': delta, # Spontaneous recovery rate\n", 363 | " 'gamma': gamma, # Reduction in infection rate from treatement\n", 364 | " 'rho': rho, # Recovery rate from treatement\n", 365 | " 'eta': eta, # Not used\n", 366 | "}\n", 367 | "cost_dict = {\n", 368 | " 'Qlam': q_x,\n", 369 | " 'Qx': q_lam\n", 370 | "}\n", 371 | "\n", 372 | "\n", 373 | "res_2 = {'sus': list(), 'inf': list(), 'rec': list()}\n", 374 | "\n", 375 | "for i in range(30):\n", 376 | " print(i+1)\n", 377 | "\n", 378 | " system = SIRDynamicalSystem(X_init, A, param_dict, cost_dict, min_d0=0.0, verbose=True, notebook=True)\n", 379 | " data = system.simulate_policy(policy_name, policy_dict, sim_dict, plot=False)\n", 380 | "\n", 381 | " h = HelperFunc()\n", 382 | " n_sus = np.sum(1 - h.sps_values(data['Y'], max_time))\n", 383 | " n_inf = np.sum(h.sps_values(data['Y'], max_time) * (1 - h.sps_values(data['W'], max_time)))\n", 384 | " n_rec = np.sum(h.sps_values(data['W'], max_time))\n", 385 | " res_2['sus'].append(n_sus)\n", 386 | " res_2['inf'].append(n_inf)\n", 387 | " res_2['rec'].append(n_rec)" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "execution_count": 24, 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "sus 736.5\n", 400 | "inf 0.0\n", 401 | "rec 263.5\n" 402 | ] 403 | } 404 | ], 405 | "source": [ 406 | "for k in res_2.keys():\n", 407 | " print(k, np.mean(res_2[k]))" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 25, 413 | "metadata": 
{}, 414 | "outputs": [ 415 | { 416 | "data": { 417 | "text/plain": [ 418 | "(array([ 5., 4., 4., 0., 0., 0., 0., 0., 0., 17.]),\n", 419 | " array([307. , 376.2, 445.4, 514.6, 583.8, 653. , 722.2, 791.4, 860.6, 929.8, 999. ]),\n", 420 | " )" 421 | ] 422 | }, 423 | "execution_count": 25, 424 | "metadata": {}, 425 | "output_type": "execute_result" 426 | }, 427 | { 428 | "data": { 429 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAD4ZJREFUeJzt3X2MZXV9x/H3p6ygi4anHSkK28FWaayxSkcLtVIFpasYaRqbsNEWlWYT21q1JmStqcb/wJpWmzbqBlaMtauW4kPwAagPpW3s2lkE2WVBULe4FNyhVG1tUkS//eOelWG6uzP33jNzL7+8X8lkzvmdM3M+OXv3M+eee+49qSokSY9+PzXpAJKkfljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEasW8uNbdiwoWZnZ9dyk5L0qLdr1677q2pmufXWtNBnZ2eZn59fy01K0qNekn9byXqecpGkRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEas6TtFJWmSZrd+emLb3nfZBau+DY/QJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqxLKFnmR7kgNJdi8Zf32S25PsSfLO1YsoSVqJlRyhXwVsWjyQ5IXAhcAvVtUvAO/qP5okaRjLFnpV3Qg8sGT4dcBlVfW/3ToHViGbJGkIo55Dfxrw/CQ7k/xDkuf0GUqSNLxRP8tlHXAicBbwHOBjSZ5SVbV0xSRbgC0AGzduHDWnJGkZox6h7weuqYGvAD8GNhxqxaraVlVzVTU3MzMzak5J0jJGLfRPAC8ESPI04Gjg/r5CSZKGt+wplyQ7gBcAG5LsB94ObAe2d5cyPghcfKjTLZKktbNsoVfV5sMselXPWSRJY/CdopLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRixb6Em2JznQ3Z1o6bI3J6kkh7yfqCRp7azkCP0qYNPSwSSnAecDd/ecSZI0gmULvapuBB44xKI/By4FvJeoJE2Bkc6hJ7kQuKeqblnBuluSzCeZX1hYGGVzkqQVGLrQk6wH/hh420rWr6ptVTVXVXMzMzPDbk6StEKjHKH/LHA6cEuSfcCpwE1JfrrPYJKk4awb9geq6lbgiQfnu1Kfq6r7e8wlSRrSSi5b3AF8GTgjyf4kl6x+LEnSsJY9Qq+qzcssn+0tjSRpZL5TVJIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1IjLHRJaoSFLkmNsNAlqREWuiQ1wkKXpEas5AYX25McSLJ70difJrk9ydeSfDzJ8asbU5K0nJUcoV8FbFoydgPwjKp6JvB14C0955IkDWnZQq+qG4EHloxdX1UPdbP/wuBG0ZKkCerjHPprgc/28HskSWMYq9CTvBV4CPjwEdbZkmQ+yfzCwsI4m5MkHcHIhZ7k1cDLgFdWVR1uvaraVlVzVTU3MzMz6uYkSctYN8oPJdkEXAr8WlX9
T7+RJEmjWMllizuALwNnJNmf5BLgL4EnADckuTnJ+1Y5pyRpGcseoVfV5kMMX7kKWSRJY/CdopLUCAtdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDViJTe42J7kQJLdi8ZOTHJDkju77yesbkxJ0nJWcoR+FbBpydhW4PNV9VTg8928JGmCli30qroReGDJ8IXAB7vpDwK/0XMuSdKQRj2HfnJV3dtN3wec3FMeSdKIxn5RtKoKqMMtT7IlyXyS+YWFhXE3J0k6jFEL/TtJTgHovh843IpVta2q5qpqbmZmZsTNSZKWM2qhfwq4uJu+GPhkP3EkSaNayWWLO4AvA2ck2Z/kEuAy4MVJ7gRe1M1LkiZo3XIrVNXmwyw6r+cskqQx+E5RSWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGmGhS1Ijxir0JG9KsifJ7iQ7kjy2r2CSpOGMXOhJngz8ITBXVc8AjgIu6iuYJGk4455yWQc8Lsk6YD3w7+NHkiSNYuRCr6p7gHcBdwP3At+rquuXrpdkS5L5JPMLCwujJ5UkHdE4p1xOAC4ETgeeBByb5FVL16uqbVU1V1VzMzMzoyeVJB3ROKdcXgR8q6oWquqHwDXAr/QTS5I0rHEK/W7grCTrkwQ4D9jbTyxJ0rDGOYe+E7gauAm4tftd23rKJUka0rpxfriq3g68vacskqQx+E5RSWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktQIC12SGjFWoSc5PsnVSW5PsjfJ2X0FkyQNZ6wbXADvAT5XVa9IcjSwvodMkqQRjFzoSY4DzgFeDVBVDwIP9hNLkjSscU65nA4sAB9I8tUkVyQ5tqdckqQhjVPo64AzgfdW1bOBHwBbl66UZEuS+STzCwsLY2xOknQk4xT6fmB/Ve3s5q9mUPCPUFXbqmququZmZmbG2Jwk6UhGLvSqug/4dpIzuqHzgNt6SSVJGtq4V7m8Hvhwd4XLN4HXjB9JkjSKsQq9qm4G5nrKIkkag+8UlaRGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpEeO+sWjNzG799MS2ve+yCya2bUlaKY/QJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqxNiFnuSo7ibR1/YRSJI0mj6O0N8A7O3h90iSxjBWoSc5FbgAuKKfOJKkUY17hP5u4FLgxz1kkSSNYeTPcknyMuBAVe1K8oIjrLcF2AKwcePGUTc3UZP8HJlJ8LNrpEencY7Qnwe8PMk+4CPAuUn+eulKVbWtquaqam5mZmaMzUmSjmTkQq+qt1TVqVU1C1wEfKGqXtVbMknSULwOXZIa0cvnoVfVl4Av9fG7JEmj8QhdkhphoUtSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktSIkQs9yWlJvpjktiR7kryhz2CSpOGMc8eih4A3V9VNSZ4A7EpyQ1Xd1lM2SdIQxrlJ9L1VdVM3/V/AXuDJfQWTJA2nl3PoSWaBZwM7D7FsS5L5JPMLCwt9bE6SdAhjF3qSxwN/B7yxqr6/dHlVbauquaqam5mZGXdzkqTDGKvQkzyGQZl/uKqu6SeSJGkU41zlEuBKYG9V/Vl/kSRJoxjnCP15wG8D5ya5uft6aU+5JElDGvmyxar6JyA9ZpEkjcF3ikpSIyx0SWqEhS5JjbDQJakRFrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY2w0CWpERa6JDXCQpekRljoktSIcW9BtynJHUnuSrK1r1CSpOGNcwu6o4C/Al4C
PB3YnOTpfQWTJA1nnCP05wJ3VdU3q+pB4CPAhf3EkiQNa5xCfzLw7UXz+7sxSdIEjHxP0ZVKsgXY0s3+d5I7hvjxDcD9/adaNU3kzeUTSLIyTezfKWbeVZTLx8r7MytZaZxCvwc4bdH8qd3YI1TVNmDbKBtIMl9Vc6PFW3vmXV3mXV3mXV1rkXecUy7/Cjw1yelJjgYuAj7VTyxJ0rBGPkKvqoeS/AFwHXAUsL2q9vSWTJI0lLHOoVfVZ4DP9JTlUEY6VTNB5l1d5l1d5l1dq543VbXa25AkrQHf+i9JjZhYoSd5bJKvJLklyZ4k7+jGT0+ys/s4gY92L7iS5Jhu/q5u+eyEch+V5KtJrp32vEn2Jbk1yc1J5ruxE5PckOTO7vsJ3XiS/EWX92tJzpxA3uOTXJ3k9iR7k5w9rXmTnNHt14Nf30/yxmnN22V4U/d/bXeSHd3/wWl+/L6hy7onyRu7sanav0m2JzmQZPeisaEzJrm4W//OJBePHKiqJvIFBHh8N/0YYCdwFvAx4KJu/H3A67rp3wPe101fBHx0Qrn/CPgb4NpufmrzAvuADUvG3gls7aa3Apd30y8FPtv9u5wF7JxA3g8Cv9tNHw0cP815F+U+CriPwbXCU5mXwZv+vgU8btHj9tXT+vgFngHsBtYzeK3v74Gfm7b9C5wDnAnsXjQ2VEbgROCb3fcTuukTRsqzlv9IR9gp64GbgF9mcOH9um78bOC6bvo64Oxuel23XtY456nA54FzgWu7f5hpzruP/1/odwCndNOnAHd00+8HNh9qvTXKelxXOFkyPpV5l2Q8H/jnac7Lw+/sPrF7PF4L/Pq0Pn6B3wKuXDT/J8Cl07h/gVkeWehDZQQ2A+9fNP6I9Yb5mug59O70xc3AAeAG4BvAd6vqoW6VxR8n8JOPGuiWfw84aW0T824GD6ofd/MnMd15C7g+ya4M3rELcHJV3dtN3wec3E1P+qMcTgcWgA90p7SuSHIs05t3sYuAHd30VOatqnuAdwF3A/cyeDzuYnofv7uB5yc5Kcl6Bke3pzGl+3eJYTP2ln2ihV5VP6qqZzE48n0u8POTzHMkSV4GHKiqXZPOMoRfraozGXwi5u8nOWfxwhocDkzLZU7rGDx1fW9VPRv4AYOnqz8xZXkB6M45vxz426XLpilvdx73QgZ/OJ8EHAtsmmioI6iqvcDlwPXA54CbgR8tWWdq9u/hrHXGqbjKpaq+C3yRwVO+45McvD5+8ccJ/OSjBrrlxwH/sYYxnwe8PMk+Bp8seS7wninOe/CojKo6AHycwR/N7yQ5pct1CoNnR4/I2znkRzmsov3A/qra2c1fzaDgpzXvQS8Bbqqq73Tz05r3RcC3qmqhqn4IXMPgMT3Nj98rq+qXquoc4D+BrzO9+3exYTP2ln2SV7nMJDm+m34c8GJgL4Nif0W32sXAJ7vpT3XzdMu/0P31WxNV9ZaqOrWqZhk8xf5CVb1yWvMmOTbJEw5OMzjPu3tJrqV5f6d7Jf4s4HuLnjauuqq6D/h2kjO6ofOA26Y17yKbefh0y8Fc05j3buCsJOuThIf371Q+fgGSPLH7vhH4TQYXI0zr/l1s2IzXAecnOaF7JnV+Nza8tXjR4DAvJDwT+CrwNQZF87Zu/CnAV4C7GDyNPaYbf2w3f1e3/CkTzP4CHr7KZSrzdrlu6b72AG/txk9i8MLunQyuHDixGw+DG5Z8A7gVmJvAfn0WMN89Jj7B4BX/ac57LIOj1uMWjU1z3ncAt3f/3z4EHDOtj98uwz8y+KNzC3DeNO5fBn/M7wV+yOBZ5iWjZARe2+3ru4DXjJrHd4pKUiOm4hy6JGl8FrokNcJCl6RGWOiS1AgLXZIaYaFLUiMsdElqhIUuSY34P+QLq+xdDgrVAAAAAElFTkSuQmCC\n", 430 | "text/plain": [ 431 | "
" 432 | ] 433 | }, 434 | "metadata": { 435 | "needs_background": "light" 436 | }, 437 | "output_type": "display_data" 438 | } 439 | ], 440 | "source": [ 441 | "plt.hist(res_2['sus'])" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": {}, 448 | "outputs": [], 449 | "source": [] 450 | } 451 | ], 452 | "metadata": { 453 | "kernelspec": { 454 | "display_name": "Python 3", 455 | "language": "python", 456 | "name": "python3" 457 | }, 458 | "language_info": { 459 | "codemirror_mode": { 460 | "name": "ipython", 461 | "version": 3 462 | }, 463 | "file_extension": ".py", 464 | "mimetype": "text/x-python", 465 | "name": "python", 466 | "nbconvert_exporter": "python", 467 | "pygments_lexer": "ipython3", 468 | "version": "3.6.8" 469 | } 470 | }, 471 | "nbformat": 4, 472 | "nbformat_minor": 2 473 | } 474 | -------------------------------------------------------------------------------- /lib/dynamics.py: -------------------------------------------------------------------------------- 1 | """ 2 | epidemic_helper.py: Helper module to simulate continuous-time stochastic 3 | SIR epidemics. 4 | Copyright © 2018 — LCA 4 5 | """ 6 | import time 7 | import bisect 8 | import numpy as np 9 | import pandas as pd 10 | import networkx as nx 11 | import scipy 12 | import scipy.optimize 13 | import scipy as sp 14 | import random as rd 15 | import heapq 16 | import collections 17 | import itertools 18 | import os 19 | 20 | from lpsolvers import solve_lp 21 | 22 | from . import maxcut 23 | from .settings import DATA_DIR 24 | 25 | 26 | def sample_seeds(graph, delta, method='data', n_seeds=None, max_date=None, verbose=True): 27 | """ 28 | Extract seeds from the Ebola cases datasets, by choosing either: 29 | * the first `n_seeds`. 30 | * the first seed until the date `max_date`. 31 | For each seed, we then simulate its recovery time and attribute it to a random node in the 32 | corresponding district. 
def sample_seeds(graph, delta, method='data', n_seeds=None, max_date=None, verbose=True):
    """
    Extract seeds from the Ebola cases datasets, by choosing either:
        * the first `n_seeds`.
        * the first seed until the date `max_date`.
    For each seed, we then simulate its recovery time and attribute it to a random node in the
    corresponding district. We then start the epidemic at the time of infection of the last seed.
    Note that some seeds may have already recovered at this time. In this case, they are just
    ignored from the simulation altogether.

    Arguments:
    ---------
    graph : nx.Graph
        The graph of individuals in districts. Nodes must have the attribute `district`.
    delta : float
        Recovery rate of the epidemic process. Used to sample recovery times of seeds.
    method : str ('data' or 'random')
        Method to sample the seeds. Can be one of:
            - 'data': Use the seeds from the dataset and sample recovery time
            - 'random': Sample random seeds along with their recovery time
    n_seeds : int
        Number of seeds to sample.
    max_date : str
        Maximum date to sample seeds (max_date is included in sampling).
    verbose : bool
        Indicate whether or not to print seed generation process.

    Returns:
    -------
    init_event_list : list
        List of initial events `[(node, event_type, infector_node), time]`.
    """
    assert (n_seeds is not None) or (max_date is not None), "Either `n_seeds` or `max_date` must be given"

    if method == 'data':

        # Load real data
        df = pd.read_csv(os.path.join(DATA_DIR, 'ebola', 'rstb20160308_si_001_cleaned.csv'))
        if n_seeds:
            df = df.sort_values('infection_timestamp').iloc[:n_seeds]
        elif max_date:
            df = df[df.infection_date <= max_date].sort_values('infection_timestamp')
        # Extract district name for each node in the graph
        node_names = np.array([u for u, d in graph.nodes(data=True)])
        node_districts = np.array([d['district'] for u, d in graph.nodes(data=True)])
        # Get last infection time of seeds (this is time zero for the simulation)
        last_inf_time = df.infection_timestamp.max()
        # Init list of seed events
        init_event_list = list()
        for _, row in df.iterrows():
            inf_time = row['infection_timestamp']
            # Sample recovery time (shifted so that `last_inf_time` is time zero)
            rec_time = inf_time + rd.expovariate(delta) - last_inf_time
            # Ignore seed if recovered before time zero
            if rec_time > 0:
                # Randomly sample one node for each seed in the corresponding district
                idx = np.random.choice(np.where(node_districts == row['district'])[0])
                node = node_names[idx]
                # Add infection event
                # node-to-node self-infection flags initial seeds in the code
                init_event_list.append([(node, 'inf', node), 0.0])  # Gets infection at the start
                # Add recovery event
                init_event_list.append([(node, 'rec', None), rec_time])
                if verbose:
                    print(f'Add seed {node} from district {row["district"]} - inf: {0.0}, rec: {rec_time} ')
        return init_event_list

    elif method == 'random':

        if n_seeds is None:
            raise ValueError("`n_seeds` must be provided for method `random`")

        init_event_list = list()
        for _ in range(n_seeds):
            # FIX: `np.random.choice` requires a 1-D array-like; a networkx
            # NodeView raises a ValueError on recent networkx versions, so
            # materialize the nodes into a list first.
            node = np.random.choice(list(graph.nodes()))
            init_event_list.append([(node, 'inf', node), 0.0])
            rec_time = rd.expovariate(delta)
            init_event_list.append([(node, 'rec', None), rec_time])

        return init_event_list

    else:
        raise ValueError('Invalid method.')
class PriorityQueue(object):
    """
    PriorityQueue with O(1) update and deletion of objects.

    Implemented as a binary heap of `[priority, count, task]` entries following
    the recipe from the `heapq` documentation: deleted tasks are marked with a
    placeholder and lazily discarded on pop, while `entry_finder` maps each
    live task to its heap entry. The monotonically increasing `counter` breaks
    priority ties in FIFO order and prevents tasks from being compared directly.
    """

    def __init__(self, initial=None, priorities=None):
        """
        Initialize the queue, optionally pushing `initial[i]` with priority
        `priorities[i]`. Both lists must have the same length.

        FIX: the defaults were the mutable literals `[]`, which are shared
        across calls; `None` sentinels preserve the exact same semantics
        without that pitfall.
        """
        initial = [] if initial is None else initial
        priorities = [] if priorities is None else priorities

        self.pq = []  # the binary heap of [priority, count, task] entries
        self.entry_finder = {}  # mapping of tasks to entries
        self.removed = ''  # placeholder for a removed task
        self.counter = itertools.count()  # unique sequence count

        assert(len(initial) == len(priorities))
        for i in range(len(initial)):
            self.push(initial[i], priority=priorities[i])

    def push(self, task, priority=0):
        """Add a new task or update the priority of an existing task."""
        if task in self.entry_finder:
            self.delete(task)
        count = next(self.counter)
        entry = [priority, count, task]
        self.entry_finder[task] = entry
        heapq.heappush(self.pq, entry)

    def delete(self, task):
        """Mark an existing task as REMOVED. Raise KeyError if not found."""
        entry = self.entry_finder.pop(task)
        entry[-1] = self.removed

    def remove_all_tasks_of_type(self, type):
        """
        Remove all existing tasks of a specific event type (for SIRSimulation).

        Tasks are assumed to be `(node, event_type, infector)` tuples; every
        task whose middle element equals `type` is deleted.
        NOTE(review): the parameter shadows the builtin `type`; kept for
        backward compatibility with keyword callers.
        """
        keys = list(self.entry_finder.keys())
        for event in keys:
            u, type_, v = event
            if type_ == type:
                self.delete(event)

    def pop_priority(self):
        """
        Remove and return the lowest priority task with its priority value.
        Raise KeyError if empty.
        """
        # Skip entries that were lazily marked as removed.
        while self.pq:
            priority, _, task = heapq.heappop(self.pq)
            if task is not self.removed:
                del self.entry_finder[task]
                return task, priority
        raise KeyError('pop from an empty priority queue')

    def pop(self):
        """
        Remove and return the lowest priority task. Raise KeyError if empty.
        """
        task, _ = self.pop_priority()
        return task

    def priority(self, task):
        """Return the current priority of `task`. Raise KeyError if absent."""
        if task in self.entry_finder:
            return self.entry_finder[task][0]
        else:
            raise KeyError('task not in queue')

    def __len__(self):
        # Only live (non-removed) tasks count towards the length.
        return len(self.entry_finder)

    def __str__(self):
        return str(self.pq)

    def __repr__(self):
        return repr(self.pq)

    def __setitem__(self, task, priority):
        # `pq[task] = p` is sugar for push/update.
        self.push(task, priority=priority)
class ProgressPrinter(object):
    """
    Helper object to print relevant information throughout the epidemic.

    In-place progress updates (`print`) are throttled to at most one every
    `PRINT_INTERVAL` seconds of wall-clock time; `println` always prints the
    final summary line.
    """
    PRINT_INTERVAL = 0.1  # min. wall-clock seconds between progress updates
    _PRINT_MSG = ('{t:.2f} days elapsed '
                  '| '
                  '{S:.0f} sus., '
                  '{I:.0f} inf., '
                  '{R:.0f} rec., '
                  '{Tt:.0f} tre ({TI:.2f}% of inf) | '
                  'max_u {max_u:.2e}'
                  )
    _PRINTLN_MSG = ('Epidemic stopped after {t:.2f} days '
                    '| '
                    '{S:.0f} sus., '
                    '{I:.0f} inf., '
                    '{R:.0f} rec., '
                    '{Tt:.0f} tre ({TI:.2f}% of inf) | '
                    'max_u {max_u:.2e}'
                    )

    def __init__(self, verbose=True):
        """
        Arguments:
        ---------
        verbose : bool
            If False, `print` and `println` become no-ops.
        """
        self.verbose = verbose
        self.last_print = time.time()

    def _stats(self, sir_obj):
        """Compute the (S, I, R, Tt, TI) population counts from the simulation state."""
        S = np.sum(sir_obj.is_sus)
        I = np.sum(sir_obj.is_inf * (1 - sir_obj.is_rec))
        R = np.sum(sir_obj.is_rec)
        Tt = np.sum(sir_obj.is_tre)
        # Percentage of currently infected nodes under treatment (NaN if none infected)
        TI = 100. * Tt / I if I > 0 else np.nan
        return S, I, R, Tt, TI

    def print(self, sir_obj, epitime, end='', force=False):
        """
        Print a throttled, in-place progress line (no-op unless verbose).

        FIX: the original computed the treatment count twice (`T` and `Tt`)
        and read four queue counters (`iq`, `rq`, `tq`, `lq`) that were never
        used; both redundancies are removed.
        """
        if not self.verbose:
            return
        if (time.time() - self.last_print > self.PRINT_INTERVAL) or force:
            S, I, R, Tt, TI = self._stats(sir_obj)
            print('\r', self._PRINT_MSG.format(t=epitime, S=S, I=I, R=R, Tt=Tt, TI=TI,
                                               max_u=sir_obj.max_total_control_intensity),
                  sep='', end='', flush=True)
            self.last_print = time.time()

    def println(self, sir_obj, epitime):
        """Print the final summary line and move to a new line (no-op unless verbose)."""
        if not self.verbose:
            return
        S, I, R, Tt, TI = self._stats(sir_obj)
        print('\r', self._PRINTLN_MSG.format(
            t=epitime, S=S, I=I, R=R, Tt=Tt, TI=TI,
            max_u=sir_obj.max_total_control_intensity),
            sep='', end='\n', flush=True)
        self.last_print = time.time()
262 | 263 | The simulation algorithm works by leveraging the Markov property of the model and rejection 264 | sampling. Events are treated in order in a priority queue. An event in the queue is a tuple 265 | the form 266 | `(node, event_type, infector_node)` 267 | where elements are as follows: 268 | `node` : is the node where the event occurs, 269 | `event_type` : is the type of event (i.e. infected 'inf', recovery 'rec', or treatement 'tre') 270 | `infector_node` : for infections only, the node of caused the infection. 271 | """ 272 | 273 | AVAILABLE_LPSOLVERS = ['scipy', 'cvxopt'] 274 | 275 | def __init__(self, G, beta, gamma, delta, rho, verbose=True): 276 | """ 277 | Init an SIR cascade over a graph 278 | 279 | Arguments: 280 | --------- 281 | G : networkx.Graph() 282 | Graph over which the epidemic propagates 283 | beta : float 284 | Exponential infection rate (positive) 285 | gamma : float 286 | Reduction in infection rate by treatment 287 | delta : float 288 | Exponential recovery rate (non-negative) 289 | rho : float 290 | Increase in recovery rate by treatment 291 | verbose : bool (default: True) 292 | Indicate the print behavior, if set to False, nothing will be printed 293 | """ 294 | if not isinstance(G, nx.Graph): 295 | raise ValueError('Invalid graph type, must be networkx.Graph') 296 | self.G = G 297 | self.A = sp.sparse.csr_matrix(nx.adjacency_matrix(self.G).toarray()) 298 | 299 | # Cache the number of nodes 300 | self.n_nodes = len(G.nodes()) 301 | self.max_deg = np.max([d for n, d in self.G.degree()]) 302 | self.min_deg = np.min([d for n, d in self.G.degree()]) 303 | self.idx_to_node = dict(zip(range(self.n_nodes), self.G.nodes())) 304 | self.node_to_idx = dict(zip(self.G.nodes(), range(self.n_nodes))) 305 | 306 | # Check parameters 307 | if isinstance(beta, (float, int)) and (beta > 0): 308 | self.beta = beta 309 | else: 310 | raise ValueError("`beta` must be a positive float") 311 | if isinstance(gamma, (float, int)) and (gamma >= 0) and (gamma 
<= beta): 312 | self.gamma = gamma 313 | else: 314 | raise ValueError(("`gamma` must be a positive float smaller than `beta`")) 315 | if isinstance(delta, (float, int)) and (delta >= 0): 316 | self.delta = delta 317 | else: 318 | raise ValueError("`delta` must be a non-negative float") 319 | if isinstance(rho, (float, int)) and (rho >= 0): 320 | self.rho = rho 321 | else: 322 | raise ValueError("`rho` must be a non-negative float") 323 | 324 | # Control pre-computations 325 | self.lrsr_initiated = False # flag for initial LRSR computation 326 | self.mcm_initiated = False # flag for initial MCM computation 327 | 328 | # Control statistics 329 | self.max_total_control_intensity = 0.0 330 | 331 | # Printer for logging 332 | self._printer = ProgressPrinter(verbose=verbose) 333 | 334 | def expo(self, rate): 335 | """Samples a single exponential random variable.""" 336 | return rd.expovariate(rate) 337 | 338 | def nodes_at_time(self, status, time): 339 | """ 340 | Get the status of all nodes at a given time 341 | """ 342 | if status == 'S': 343 | return self.inf_occured_at > time 344 | elif status == 'I': 345 | return (self.rec_occured_at > time) * (self.inf_occured_at < time) 346 | elif status == 'T': 347 | return (self.tre_occured_at < time) * (self.rec_occured_at > time) 348 | elif status == 'R': 349 | return self.rec_occured_at < time 350 | else: 351 | raise ValueError('Invalid status.') 352 | 353 | def _init_run(self, init_event_list, max_time): 354 | """ 355 | Initialize the run of the epidemic 356 | """ 357 | 358 | # Max time of the run 359 | self.max_time = max_time 360 | 361 | # Priority queue of events by time 362 | # event invariant is ('node', event, 'node') where the second node is the infector if applicable 363 | self.queue = PriorityQueue() 364 | # Cache the number of ins, recs, tres in the queue 365 | self.infs_in_queue = 0 366 | self.recs_in_queue = 0 367 | self.tres_in_queue = 0 368 | 369 | # Susceptible nodes tracking: is_sus[node]=1 if node is 
currently susceptible) 370 | self.initial_seed = np.zeros(self.n_nodes, dtype='bool') 371 | self.is_sus = np.ones(self.n_nodes, dtype='bool') # True if u susceptible 372 | 373 | # Infection tracking: is_inf[node]=1 if node has been infected 374 | # (note that the node can be already recovered) 375 | self.inf_occured_at = np.inf * np.ones(self.n_nodes, dtype='float') # time infection of u_idx occurred 376 | self.is_inf = np.zeros(self.n_nodes, dtype='bool') # True if u_idx infected 377 | self.infector = -1 * np.ones(self.n_nodes, dtype='int') # index of node that infected u_idx (if -1, then no infector) 378 | self.num_child_inf = np.zeros(self.n_nodes, dtype='int') # number of neighbors u_idx infected 379 | 380 | # Recovery tracking: is_rec[node]=1 if node is currently recovered 381 | self.rec_occured_at = np.inf * np.ones(self.n_nodes, dtype='float') # time recovery of u_idx occured 382 | self.is_rec = np.zeros(self.n_nodes, dtype='bool') # True if u_idx recovered 383 | 384 | # Treatment tracking: is_tre[node]=1 if node is currently treated 385 | self.tre_occured_at = np.inf * np.ones(self.n_nodes, dtype='float') # time treatment of u_idx occured 386 | self.is_tre = np.zeros(self.n_nodes, dtype='bool') # True if u_idx treated 387 | 388 | # Conrol tracking 389 | self.old_lambdas = np.zeros(self.n_nodes, dtype='float') # control intensity of prev iter 390 | self.max_interventions_reached = False 391 | 392 | # Add the initial events to priority queue 393 | for event, time in init_event_list: 394 | u, event_type, _ = event 395 | u_idx = self.node_to_idx[u] 396 | self.initial_seed[u_idx] = True 397 | if event_type == 'inf': 398 | # Initial infections have infections from NaN to u 399 | self.queue.push(event, priority=time) 400 | self.infs_in_queue += 1 401 | elif event_type == 'rec': 402 | self.queue.push(event, priority=time) 403 | self.recs_in_queue += 1 404 | else: 405 | raise ValueError('Invalid Event Type for initial seeds.') 406 | 407 | def 
_process_infection_event(self, u, time, w): 408 | """ 409 | Mark node `u` as infected at time `time` 410 | Sample its recovery time and its neighbors infection times and add to the queue 411 | """ 412 | # Get node index 413 | u_idx = self.node_to_idx[u] 414 | # Handle infection event 415 | self.is_inf[u_idx] = True 416 | self.is_sus[u_idx] = False 417 | self.inf_occured_at[u_idx] = time 418 | if self.initial_seed[u_idx]: 419 | # Handle initial seeds 420 | self.infector[u_idx] = -1 421 | else: 422 | w_idx = self.node_to_idx[w] 423 | self.infector[u_idx] = w_idx 424 | self.num_child_inf[w_idx] += 1 425 | recovery_time_u = time + self.expo(self.delta) 426 | self.queue.push((u, 'rec', None), priority=recovery_time_u) 427 | self.recs_in_queue += 1 428 | # Set neighbors infection events 429 | for v in self.G.neighbors(u): 430 | v_idx = self.node_to_idx[v] 431 | if self.is_sus[v_idx]: 432 | infection_time_v = time + self.expo(self.beta) 433 | self.queue.push((v, 'inf', u), priority=infection_time_v) 434 | self.infs_in_queue += 1 435 | 436 | def _process_recovery_event(self, u, time): 437 | """ 438 | Mark node `node` as recovered at time `time` 439 | """ 440 | # Get node index 441 | u_idx = self.node_to_idx[u] 442 | # Handle recovery event 443 | self.rec_occured_at[u_idx] = time 444 | self.is_rec[u_idx] = True 445 | 446 | def _process_treatment_event(self, u, time): 447 | """ 448 | Mark node `u` as treated at time `time` 449 | Update its recovery time and its neighbors infection times and the queue 450 | """ 451 | # Get node index 452 | u_idx = self.node_to_idx[u] 453 | # Handle treatement event 454 | self.tre_occured_at[u_idx] = time 455 | self.is_tre[u_idx] = True 456 | # Update own recovery event with rejection sampling 457 | assert(self.rho <= 0) 458 | if rd.random() < - self.rho / self.delta: 459 | # reject previous event 460 | self.queue.delete((u, 'rec', None)) 461 | # re-sample 462 | new_recovery_time_u = time + self.expo(self.delta + self.rho) 463 | 
self.queue.push((u, 'rec', None), priority=new_recovery_time_u) 464 | # Update neighbors infection events triggered by u 465 | for v in self.G.neighbors(u): 466 | v_idx = self.node_to_idx[v] 467 | if self.is_sus[v_idx]: 468 | if rd.random() < self.gamma / self.beta: 469 | # reject previous event 470 | self.queue.delete((v, 'inf', u)) 471 | # re-sample 472 | if self.beta - self.gamma > 0: 473 | new_infection_time_v = time + self.expo(self.beta - self.gamma) 474 | else: 475 | # Avoid DivisionByZeroError if beta = gamma 476 | # i.e., if no infectivity under treatement, then set infection time to inf 477 | # We still set an event to make the algo easier and avoid bugs 478 | new_infection_time_v = np.inf 479 | self.queue.push((v, 'inf', u), priority=new_infection_time_v) 480 | 481 | def _control(self, u, time, policy='NO'): 482 | # Get node index 483 | u_idx = self.node_to_idx[u] 484 | # Check if max interventions were reached (for FL) 485 | if '-FL' in policy: 486 | max_interventions = self.policy_dict['front-loading']['max_interventions'] 487 | current_interventions = np.sum(self.is_tre) 488 | if current_interventions > max_interventions: 489 | # End interventions for this simulation 490 | self.max_interventions_reached = True 491 | self.queue.remove_all_tasks_of_type('tre') 492 | print('All treatments ended') 493 | return 494 | # Compute control intensity 495 | self.new_lambda = self._compute_lambda(u, time, policy=policy) 496 | # Sample treatment event 497 | delta = self.new_lambda - self.old_lambdas[u_idx] 498 | if delta < 0: 499 | # Update treatment event with rejection sampling as intensity was reduced 500 | if rd.random() < 1 - self.new_lambda / self.old_lambdas[u_idx]: 501 | # reject previous event 502 | self.queue.delete((u, 'tre', None)) 503 | if self.new_lambda > 0: 504 | # re-sample 505 | new_treatment_time_u = time + self.expo(self.new_lambda) 506 | self.queue.push((u, 'tre', None), priority=new_treatment_time_u) 507 | elif delta > 0: 508 | # Sample 
new/additional treatment event with the superposition principle 509 | new_treatment_time_u = time + self.expo(delta) 510 | self.queue.push((u, 'tre', None), priority=new_treatment_time_u) 511 | self.tres_in_queue += 1 512 | # store lambda 513 | self.old_lambdas[u_idx] = self.new_lambda 514 | 515 | def _compute_lambda(self, u, time, policy='NO'): 516 | """Computes control intensity of the respective policy""" 517 | 518 | if policy == 'NO': 519 | return 0.0 520 | 521 | elif policy == 'TR': 522 | # lambda = const. 523 | return self.policy_dict['TR'] 524 | 525 | elif policy == 'TR-FL': 526 | return self._frontloadPolicy( 527 | self.policy_dict['TR'], 528 | self.policy_dict['TR'], time) 529 | 530 | elif policy == 'MN': 531 | # lambda ~ deg(u) 532 | return self.G.degree(u) * self.policy_dict['MN'] 533 | 534 | elif policy == 'MN-FL': 535 | return self._frontloadPolicy( 536 | self.G.degree(u) * self.policy_dict['MN'], 537 | self.max_deg * self.policy_dict['MN'], time) 538 | 539 | elif policy == 'LN': 540 | # lambda ~ (maxdeg - deg(u) + 1) 541 | return (self.max_deg - self.G.degree(u) + 1) * self.policy_dict['LN'] 542 | 543 | elif policy == 'LN-FL': 544 | return self._frontloadPolicy( 545 | (self.max_deg - self.G.degree(u) + 1) * self.policy_dict['LN'], 546 | (self.max_deg - self.min_deg + 1) * self.policy_dict['LN'], time) 547 | 548 | elif policy == 'LRSR': 549 | # lambda ~ 1/rank 550 | # where rank is order of largest reduction in spectral radius of A 551 | intensity, _ = self._compute_LRSR_lambda(u, time) 552 | return intensity 553 | 554 | elif policy == 'LRSR-FL': 555 | intensity, max_intensity = self._compute_LRSR_lambda(u, time) 556 | return self._frontloadPolicy( 557 | intensity, max_intensity, time) 558 | 559 | elif policy == 'MCM': 560 | # lambda ~ 1/rank 561 | # where rank is MCM heuristic ranking 562 | intensity, _ = self._compute_MCM_lambda(u, time) 563 | return intensity 564 | 565 | elif policy == 'MCM-FL': 566 | intensity, max_intensity = 
self._compute_MCM_lambda(u, time) 567 | return self._frontloadPolicy( 568 | intensity, max_intensity, time) 569 | 570 | elif policy == 'SOC': 571 | return self._compute_SOC_lambda(u, time) 572 | 573 | else: 574 | raise KeyError('Invalid policy code.') 575 | 576 | def _frontloadPolicy(self, intensity, max_intensity, time): 577 | """ 578 | Return front-loaded variation of policy u at time t 579 | Scales a given `intensity` such that the policy's current 580 | `max_intensity` is equal to the SOC's `max_lambda` 581 | """ 582 | max_lambda = self.policy_dict['front-loading']['max_lambda'] 583 | 584 | # scale proportionally s.t. max(u) = max(u_SOC) 585 | if max_intensity > 0.0: 586 | return max_lambda * intensity / max_intensity 587 | else: 588 | return 0.0 589 | 590 | def _compute_LRSR_lambda(self, u, time): 591 | 592 | # TODO 593 | # raise ValueError('Currently too slow for big networks. Eigenvalues of A need to be found |V| times using brute force.') 594 | 595 | # lambda ~ 1/rank 596 | # where rank is order of largest reduction in spectral radius of A 597 | if self.lrsr_initiated: 598 | intensity = 1.0 / (1.0 + np.where(self.spectral_ranking == u)[0]) * self.policy_dict['LRSR'] 599 | max_intensity = self.policy_dict['LRSR'] 600 | 601 | # return both u's intensity and max intensity of all nodes for potential FL 602 | return intensity, max_intensity 603 | else: 604 | # first time: compute ranking for all nodes 605 | def spectral_radius(A): # TODO not tested yet 606 | return np.max(scipy.linalg.eigvalsh(self.A, turbo=True, eigvals=(self.n_nodes - 2, self.n_nodes - 1))) 607 | 608 | # Brute force: 609 | # find which node removals reduce spectral radius the most 610 | tau = spectral_radius(A) 611 | reduction_by_node = np.zeros(self.n_nodes) 612 | 613 | last_print = time.time() 614 | for n in range(self.n_nodes): 615 | A_ = np.copy(A) 616 | A_[n, :] = np.zeros(self.n_nodes) 617 | A_[:, n] = np.zeros(self.n_nodes) 618 | reduction_by_node[n] = tau - spectral_radius(A_) 619 | 
620 | # printing 621 | if (time.time() - last_print > 0.1): 622 | last_print = time.time() 623 | done = 100 * n / self.n_nodes 624 | print('\r', f'Computing LRSR ranking... {done:.2f}%', 625 | sep='', end='', flush=True) 626 | 627 | order = np.argsort(reduction_by_node) 628 | self.spectral_ranking_idx = np.flip(order) 629 | self.spectral_ranking = np.vectorize(self.idx_to_node.get)(self.spectral_ranking_idx) 630 | self.lrsr_initiated = True 631 | 632 | intensity = 1.0 / (1.0 + np.where(self.spectral_ranking == u)) * self.policy_dict['LRSR'] 633 | max_intensity = self.policy_dict['LRSR'] 634 | 635 | # return both u's intensity and max intensity of all nodes for potential FL 636 | return intensity, max_intensity 637 | 638 | def _compute_MCM_lambda(self, u, time): 639 | """ 640 | Return the adapted heuristic policy MaxCutMinimzation (MCM) at 641 | time `t`. The method is adapted to fit the setup where treatment 642 | intensity `rho` is the equal for everyone, and the control is made on 643 | the rate of intervention, not the intensity of the treatment itself. 644 | """ 645 | 646 | # # TODO 647 | if self.n_nodes > 5000: 648 | raise ValueError('Currently too slow for big networks. 
Eigenvalues of A needed.') 649 | 650 | if self.mcm_initiated: 651 | intensity = 1.0 / (1.0 + np.where(self.mcm_ranking == u)[0]) * self.policy_dict['MCM'] 652 | max_intensity = self.policy_dict['MCM'] 653 | 654 | # return both u's intensity and max intensity of all nodes for potential FL 655 | return intensity, max_intensity 656 | 657 | else: 658 | # first time: compute ranking for all nodes 659 | order = maxcut.mcm(self.A) 660 | self.mcm_ranking_idx = np.flip(order) 661 | 662 | self.mcm_ranking = np.vectorize(self.idx_to_node.get)(self.mcm_ranking_idx) 663 | self.mcm_initiated = True 664 | 665 | intensity = 1.0 / (1.0 + np.where(self.mcm_ranking == u)[0]) * self.policy_dict['MCM'] 666 | max_intensity = self.policy_dict['MCM'] 667 | 668 | # return both u's intensity and max intensity of all nodes for potential FL 669 | return intensity, max_intensity 670 | 671 | def _compute_SOC_lambda(self, u, time): 672 | """ 673 | Compute the stochastic optimal control rate for node `u` at time `time` 674 | """ 675 | 676 | # Set the variable `d`, where elements corresponding to susceptible individuals come from 677 | # the pre-computed linear program, and the others are zero. 
678 | d = np.zeros(self.n_nodes) 679 | d[self.lp_d_S_idx] = self.lp_d_S 680 | 681 | K1 = self.beta * (2 * self.delta + self.eta + self.rho) 682 | K2 = self.beta * (self.delta + self.eta) * (self.delta + self.eta + self.rho) * self.q_lam 683 | K3 = self.eta * (self.gamma * (self.delta + self.eta) + self.beta * (self.delta + self.rho)) 684 | K4 = self.beta * (self.delta + self.rho) * self.q_x 685 | 686 | cache = float(np.dot(self.A[self.node_to_idx[u]].toarray(), d)) 687 | intensity = - 1.0 / (K1 * self.q_lam) * (K2 - np.sqrt(2.0 * K1 * self.q_lam * (K3 * cache + K4) + K2 ** 2.0)) 688 | 689 | if intensity < 0.0: 690 | raise ValueError("Control intensity has to be non-negative.") 691 | 692 | return intensity 693 | 694 | def _update_LP_sol(self): 695 | """ 696 | Update the solution `d_S` of the linear program for the SOC policy. The solution is then 697 | cached in attribute `lp_d_S` along with the corresponding indices `lp_d_S_idx` of 698 | susceptible individuals. 699 | 700 | Note: To speed up convergence of the linear program optimization, we use the Epigraph 701 | trick and transform the equality contraint into an inequality. 
702 | """ 703 | 704 | # find subarrays 705 | x_S = np.where(self.is_sus)[0] # Indices of susceptible individuals 706 | x_I = np.where(self.is_inf)[0] # Indices of infected/recovered individuals 707 | len_S = x_S.shape[0] 708 | len_I = x_I.shape[0] 709 | A_IS = self.A[np.ix_(x_I, x_S)] 710 | 711 | K3 = self.eta * (self.gamma * (self.delta + self.eta) + self.beta * (self.delta + self.rho)) 712 | K4 = self.beta * (self.delta + self.rho) * self.q_x 713 | 714 | # objective: c^T x 715 | c = np.hstack((np.ones(len_I), np.zeros(len_S))) 716 | 717 | # inequality: Ax <= b 718 | A_ineq = sp.sparse.hstack([ 719 | sp.sparse.csr_matrix((len_I, len_I)), 720 | - A_IS 721 | ]) 722 | 723 | # equality: Ax = b 724 | A_eq = sp.sparse.hstack([ 725 | - sp.sparse.eye(len_I), 726 | A_IS 727 | ]) 728 | 729 | A = sp.sparse.vstack([A_ineq, A_eq]) 730 | 731 | b = np.hstack([ 732 | K4 / K3 * np.ones(len_I) - 1e-8, # b_ineq 733 | -K4 / K3 * np.ones(len_I) + 1e-8 # b_eq 734 | ]) 735 | 736 | bounds = tuple([(0.0, None)] * len_I + [(None, None)] * len_S) 737 | 738 | if self.lpsolver == 'scipy': 739 | 740 | result = scipy.optimize.linprog( 741 | c, A_ub=A, b_ub=b, 742 | bounds=bounds, 743 | method="interior-point", 744 | options={'tol': 1e-8}) 745 | 746 | if result['success']: 747 | d_S = result['x'][len_I:] 748 | else: 749 | raise Exception("LP couldn't be solved.") 750 | 751 | elif self.lpsolver == 'cvxopt': 752 | 753 | A_dense = A.toarray() 754 | res = solve_lp(c, A_dense, b, None, None) 755 | d_S = res[len_I:] 756 | 757 | else: 758 | raise KeyError('Invalid LP Solver.') 759 | 760 | # store LP solution 761 | self.lp_d_S = d_S 762 | self.lp_d_S_idx = x_S 763 | 764 | def launch_epidemic(self, init_event_list, max_time=np.inf, policy='NO', policy_dict={}, stop_criteria=None): 765 | """ 766 | Run the epidemic, starting from initial event list, for at most `max_time` units of time 767 | """ 768 | 769 | self._init_run(init_event_list, max_time) 770 | self.policy = policy 771 | self.policy_dict = 
policy_dict 772 | 773 | # Set SOC control parameters 774 | # TODO: Handle policy parameters better 775 | if policy == 'SOC': 776 | self.eta = policy_dict['eta'] 777 | self.q_x = policy_dict['q_x'] 778 | self.q_lam = policy_dict['q_lam'] 779 | if policy_dict.get('lpsolver') in self.AVAILABLE_LPSOLVERS: 780 | self.lpsolver = policy_dict['lpsolver'] 781 | else: 782 | raise ValueError("Invalid `lpsolver`") 783 | 784 | time = 0.0 785 | 786 | while self.queue: 787 | # Get the next event to process 788 | (u, event_type, w), time = self.queue.pop_priority() 789 | 790 | # Update queue cache 791 | if event_type == 'inf': 792 | self.infs_in_queue -= 1 793 | elif event_type == 'rec': 794 | self.recs_in_queue -= 1 795 | elif event_type == 'tre': 796 | self.tres_in_queue -= 1 797 | 798 | # Get node index 799 | u_idx = self.node_to_idx[u] 800 | 801 | # Stop at the end of the observation window 802 | if time > self.max_time: 803 | time = self.max_time 804 | break 805 | 806 | # Process the event 807 | # Check validity of infection event (node u is not infected yet) 808 | if (event_type == 'inf') and (not self.is_inf[u_idx]): 809 | assert self.is_sus[u_idx], f"Node `{u}` should be susceptible to be infected" 810 | w_idx = self.node_to_idx[w] 811 | if self.initial_seed[u_idx] or (not self.is_rec[w_idx]): 812 | self._process_infection_event(u, time, w) 813 | # Check validity of recovery event (node u is not recovered yet) 814 | elif (event_type == 'rec') and (not self.is_rec[u_idx]): 815 | assert self.is_inf[u_idx], f"Node `{u}` should be infected to be recovered" 816 | self._process_recovery_event(u, time) 817 | # Check validity of treatement event (node u is not treated yet, and not recovered) 818 | elif (event_type == 'tre') and (not self.is_tre[u_idx]) and (not self.is_rec[u_idx]): 819 | assert self.is_inf[u_idx], f"Node `{u}` should be infected to be treated" 820 | self._process_treatment_event(u, time) 821 | 822 | # If no-one is infected, the epidemic is finished. 
Stop the simulation. 823 | if np.sum(self.is_inf * (1 - self.is_rec)) == 0: 824 | break 825 | 826 | if stop_criteria: 827 | if stop_criteria(self): 828 | break 829 | 830 | # Update Control for infected nodes still untreated 831 | if not self.max_interventions_reached: 832 | controlled_nodes = np.where(self.is_inf * (1 - self.is_rec) * (1 - self.is_tre))[0] 833 | if self.policy == 'SOC': 834 | self._update_LP_sol() 835 | for u_idx in controlled_nodes: 836 | self._control(self.idx_to_node[u_idx], time, policy=self.policy) 837 | self.max_total_control_intensity = max( 838 | self.max_total_control_intensity, self.old_lambdas.sum()) 839 | 840 | self._printer.print(self, time) 841 | 842 | self._printer.println(self, time) 843 | 844 | # Free memory 845 | del self.queue 846 | --------------------------------------------------------------------------------