├── setup.cfg ├── requirements-tdd.txt ├── figs ├── r0_hist.png ├── day_epicurve.png ├── mers_r0_hist.png ├── MERS_casetree.png ├── generation_hist.png ├── mers_casetree.png ├── month_epicurve.png ├── steppe_exmple.png ├── year_epicurve.png ├── example_case_tree.png ├── oddsratio_example.png ├── test_checkerboard.png ├── mers_generation_hist.png ├── mers_sex_generation.png ├── rollingprop_example.png ├── test_generation_hist.png ├── example_casetree_health.png └── example_data_generation_hist.png ├── requirements.txt ├── add_to_docs.txt ├── .gitignore ├── epipy ├── test │ ├── test_data_generator.py │ ├── test_epicurve_plot.py │ ├── test_casetree.py │ ├── test_basics.py │ └── test_analyses.py ├── __init__.py ├── stripe_plot.py ├── rolling_proportion.py ├── data_generator.py ├── epicurve.py ├── or_plot.py ├── checkerboard.py ├── basics.py ├── case_tree.py └── analyses.py ├── MANIFEST ├── setup.py ├── docs ├── FAQ.md ├── examples_0.0.3.py └── examples_0.0.2.py └── README.md /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /requirements-tdd.txt: -------------------------------------------------------------------------------- 1 | pytest==2.6.4 2 | seaborn==0.5.1 3 | 4 | -------------------------------------------------------------------------------- /figs/r0_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/r0_hist.png -------------------------------------------------------------------------------- /figs/day_epicurve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/day_epicurve.png -------------------------------------------------------------------------------- /figs/mers_r0_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/mers_r0_hist.png -------------------------------------------------------------------------------- /figs/MERS_casetree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/MERS_casetree.png -------------------------------------------------------------------------------- /figs/generation_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/generation_hist.png -------------------------------------------------------------------------------- /figs/mers_casetree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/mers_casetree.png -------------------------------------------------------------------------------- /figs/month_epicurve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/month_epicurve.png -------------------------------------------------------------------------------- /figs/steppe_exmple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/steppe_exmple.png -------------------------------------------------------------------------------- /figs/year_epicurve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/year_epicurve.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | networkx==1.8.1 2 | numpy==1.8.0 3 | pandas==0.13.1 4 | scipy==0.13.3 5 | 6 | -------------------------------------------------------------------------------- /figs/example_case_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/example_case_tree.png -------------------------------------------------------------------------------- /figs/oddsratio_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/oddsratio_example.png -------------------------------------------------------------------------------- /figs/test_checkerboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/test_checkerboard.png -------------------------------------------------------------------------------- /figs/mers_generation_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/mers_generation_hist.png -------------------------------------------------------------------------------- /figs/mers_sex_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/mers_sex_generation.png -------------------------------------------------------------------------------- /figs/rollingprop_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/rollingprop_example.png -------------------------------------------------------------------------------- /figs/test_generation_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/test_generation_hist.png -------------------------------------------------------------------------------- /figs/example_casetree_health.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/example_casetree_health.png -------------------------------------------------------------------------------- /figs/example_data_generation_hist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cmrivers/epipy/HEAD/figs/example_data_generation_hist.png -------------------------------------------------------------------------------- /add_to_docs.txt: -------------------------------------------------------------------------------- 1 | Before next release, add the following to the online documentation: 2 | generate_data() 3 | diagnostic_accuracy() 4 | kappa_agreement() 5 | attributable_risk() 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | website/ 3 | build/ 4 | epipy/htmlcov/ 5 | epipy/.coverage 6 | epipy/_pycache_ 7 | epipy/stset 8 | epipy/example_data 9 | .DS_Store 10 | # packages 11 | *.egg 12 | *.egg-info 13 | dist 14 | build 15 | eggs 16 | parts 17 | bin 18 | 19 | 20 | # ipython 21 | .ipynb_checkpoints 22 | -------------------------------------------------------------------------------- /epipy/test/test_data_generator.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | import data_generator 8 | 9 | 10 | def test_generate_example_data(): 11 | data = data_generator.generate_example_data(cluster_size=5, outbreak_len=180, 12 | clusters=10, gen_time=5, attribute='health') 13 | 14 | assert len(data.Cluster.unique()) == 10 15 | 16 | -------------------------------------------------------------------------------- /MANIFEST: -------------------------------------------------------------------------------- 1 | # file GENERATED by distutils, do NOT edit 2 | setup.cfg 3 | setup.py 4 | epipy/__init__.py 5 | epipy/analyses.py 6 | epipy/basics.py 7 | epipy/case_tree.py 8 | epipy/checkerboard.py 9 | epipy/data_generator.py 10 | epipy/epicurve.py 11 | epipy/test_analyses.py 12 | epipy/test_basics.py 13 | epipy/test_casetree.py 14 | epipy/test_data_generator.py 15 | epipy/test_epicurve_plot.py 16 | epipy/data/example_data.csv 17 | epipy/data/mers_line_list.csv 18 | -------------------------------------------------------------------------------- /epipy/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # tools 3 | from .data_generator import generate_example_data 4 | from .basics import date_convert, group_clusters, cluster_builder 5 | from .analyses import generation_analysis, reproduction_number, create_2x2 6 | from .analyses import analyze_2x2, odds_ratio, relative_risk, chi2 7 | from .analyses import summary 8 | from .rolling_proportion import rolling_proportion 9 | 10 | # plotting 11 | from .case_tree import build_graph, case_tree_plot 12 | from .epicurve import epicurve_plot 13 | from .checkerboard import checkerboard_plot 14 | from .or_plot import or_plot 15 | from .stripe_plot import stripe_plot 16 | 17 | def get_data(fname): 18 | """Returns pandas dataframe of a line listing. 19 | Choices are 'example_data' and 'mers_line_list' 20 | Example_data is fake data. Mers_line_list is of the MERS-CoV outbreak 21 | of 2012-2014. 22 | """ 23 | import os 24 | import pandas as pd 25 | 26 | this_dir, this_filename = os.path.split(__file__) 27 | DATA_PATH = os.path.join(this_dir, "data", fname+".csv") 28 | 29 | data = pd.read_csv(DATA_PATH) 30 | 31 | return data 32 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | 3 | setup( 4 | name = "epipy", 5 | version = "0.0.2.1", 6 | author = "Caitlin Rivers", 7 | author_email = "caitlin.rivers@gmail.com", 8 | description = "Python tools for epidemiology.", 9 | url = 'http://github.com/cmrivers/epipy', 10 | download_url = 'https://github.com/cmrivers/epipy/tarball/0.0.2.1', 11 | #install_requires = ['Numpy >= 1.6.2', 12 | # 'Matplotlib >=1.2.0', 13 | # 'Networkx >=1.6.0', 14 | # 'Pandas >= 0.12.0', 15 | # 'Scipy >= 0.13'], 16 | license = "MIT", 17 | keywords = "epidemiology", 18 | packages = ['epipy'], 19 | include_package_data=True, 20 | package_data={'epipy': ['data/*.csv']}, 21 | scripts = ['epipy/basics.py', 22 | 'epipy/case_tree.py', 23 | 'epipy/checkerboard.py', 24 | 'epipy/data_generator.py', 25 | 'epipy/epicurve.py', 26 | 'epipy/analyses.py', 27 | 'epipy/or_plot.py'], 28 | long_description='README.md', 29 | classifiers=[ 30 | "Development Status :: Alpha", 31 | "Programming Language :: Python :: 3.6", 32 | "Natural Language :: English", 33 | 'Topic :: Scientific/Engineering', 34 | 'Topic :: Scientific/Engineering :: Mathematics'], 35 | ) 36 | -------------------------------------------------------------------------------- /epipy/test/test_epicurve_plot.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import pandas as pd 5 | from random import sample 6 | import epicurve 7 | 8 | def _test_data(): 9 | times = pd.date_range('12/1/2013', periods=60, freq='d') 10 | times = times.to_datetime() 11 | _data = [(0, 'ClusterA', 0, 0), 12 | (1, 'ClusterB', 1, 1), 13 | (2, 'ClusterA', 0, 0), 14 | (3, 'ClusterA', 0, 2), 15 | (4, 'ClusterB', 1, 1), 16 | (5, 'ClusterB', 1, 4)] 17 | df = pd.DataFrame(_data, columns=['case_id', 'cluster', 'index_node', 'source_node']) 18 | df['pltdate'] = sample(times, len(_data)) 19 | 20 | return df 21 | 22 | 23 | def test_epicurve_plot_month(): 24 | data = _test_data() 25 | curve, fig, ax = epicurve.epicurve_plot(data, 'pltdate', 'm') 26 | 27 | assert len(curve) == 2 28 | assert curve['count'].sum() == 6 29 | 30 | 31 | def test_epicurve_plot_day(): 32 | data = _test_data() 33 | curve, fig, ax = epicurve.epicurve_plot(data, 'pltdate', 'd') 34 | 35 | assert len(curve) == 6 36 | assert curve['count'].sum() == 6 37 | 38 | 39 | def test_epicurve_plot_year(): 40 | data = _test_data() 41 | curve, fig, ax = epicurve.epicurve_plot(data, 'pltdate', 'y') 42 | 43 | assert len(curve) == 2 44 | assert curve['count'].sum() == 6 45 | 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /epipy/test/test_casetree.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from datetime import datetime, timedelta 7 | import pytest 8 | import case_tree 9 | 10 | def _test_data(): 11 | times = pd.date_range('1/1/2013', periods=6, freq='d') 12 | times = times.to_datetime() 13 | _data = [(0, 'ClusterA', 0, 0), 14 | (1, 'ClusterB', 1, 1), 15 | (2, 'ClusterA', 0, 0), 16 | (3, 'ClusterA', 0, 2), 17 | (4, 'ClusterB', 1, 1), 18 | (5, 'ClusterB', 1, 4)] 19 | df = pd.DataFrame(_data, columns=['case_id', 'cluster', 'index_node', 'source_node']) 20 | df['pltdate'] = times 21 | 22 | return df 23 | 24 | 25 | def test_build_graph_graph(): 26 | data = _test_data() 27 | G = case_tree.build_graph(data, 'cluster', 'case_id', 'pltdate', 'cluster', 1, 1) 28 | 29 | assert len(G.node) == 6 30 | edges = [(0, 0), (1, 1), (0, 2), (2, 3), (1, 4), (4, 5)] 31 | assert len(G.edges()) == len(edges) 32 | for tup in G.edges(): 33 | assert tup in edges 34 | 35 | 36 | def test_build_graph_generation(): 37 | data = _test_data() 38 | G = case_tree.build_graph(data, 'cluster', 'case_id', 'pltdate', 'cluster', 1, 1) 39 | 40 | assert G.node[0]['generation'] == 0 41 | assert G.node[1]['generation'] == 0 42 | assert G.node[2]['generation'] == 1 43 | assert G.node[3]['generation'] == 2 44 | assert G.node[4]['generation'] == 1 45 | assert G.node[5]['generation'] == 2 46 | 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /docs/FAQ.md: -------------------------------------------------------------------------------- 1 | Updated July 2017 2 | 3 | # FAQ 4 | ----- 5 | ###What is epipy? 6 | 7 | Epipy is a python package for epidemiology. It contains 8 | tools for analyzing and visualizing epidemiology data. 9 | 10 | ###What is a case tree plot? 11 | 12 | Case tree plots are primarily a visualization of zoonotic disease transmission. 13 | However, they can also be used to visualizing environmentally-acquired 14 | dieases, or anything that emerges multiple times, is passed from person 15 | to person, and then dies out. 16 | 17 | ###How do I read a case tree plot? 18 | 19 | The x axis is time of illness onset or diagnosis, and the y axis is 20 | generation. Nodes at generation 0 are known as index nodes. 21 | In the case of a zoonotic disease, the index node is a human case 22 | acquired from an animal source. If that human were to pass 23 | the disease to two other humans, those two subsequent cases are both 24 | generation 1. 25 | 26 | The meaning of the color of the node varies based on the node attribute. 27 | In many cases, color just signifies membership to a particular transmission cluster. 28 | However, it could also represent health status (e.g. alive, dead), the sex of the patient, etc. 29 | 30 | ###What is a checkerboard plot? 31 | 32 | Checkerboard plots display similar data as a case tree plot, but instead 33 | of a network it shows a simple time series for each human to human cluster. 34 | It does not attempt to determine the structure of the transmission network. 35 | The number in the center of each check is the case id that corresponds 36 | to the the line listing. 37 | 38 | ###I have a question/complaint/compliment. 39 | Contact me at caitlinrivers@gmail.com, or @cmyeaton. Feel free to contribute! 40 | -------------------------------------------------------------------------------- /epipy/stripe_plot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Apr 5 09:02:11 2015 4 | 5 | @author: caitlin 6 | """ 7 | import pandas as pd 8 | import matplotlib.pyplot as plt 9 | import matplotlib.lines as mlines 10 | import numpy as np 11 | import seaborn as sns 12 | 13 | # %% 14 | def stripe_plot(df, yticks, date1, date2, color_col, color_dict=False, legend=True, fig=None, ax=None): 15 | """ 16 | df = pandas dataframe of line list 17 | yticks = column with ytick labels, e.g. case number 18 | date1 = start date, e.g. onset date 19 | date2 = end date, e.g. death date 20 | color = column with color identifiers, e.g. patient sex or outcome 21 | color_dict = optional dictionary with color categories as keys, and colors as values 22 | legend = boolean, optional 23 | ------------ 24 | returns fig, ax 25 | ------------ 26 | Example useage: 27 | fig, ax = stripe_plot(cases, 'case_id', 'onset_date', 'combined_outcome_date', 'categorical_outcome') 28 | """ 29 | if fig == None and ax == None: 30 | fig, ax = plt.subplots() 31 | ax.set_aspect('auto') 32 | fig.autofmt_xdate() 33 | 34 | ax.xaxis_date() 35 | ax.set_ylim(-1, len(df)) 36 | 37 | 38 | if color_dict == False: 39 | color_keys = df[color_col].unique() 40 | color_tuple = sns.color_palette('deep', len(color_keys)) 41 | color_dict = dict(zip(color_keys, color_tuple)) 42 | 43 | 44 | if legend == True: 45 | leg_items = [] 46 | for k, v in color_dict.iteritems(): 47 | item = mlines.Line2D([], [], color=v, marker='o', label=k) 48 | leg_items.append(item) 49 | 50 | ax.legend(handles=leg_items, loc='best') 51 | 52 | 53 | counter = 0 54 | for ix in df[date1].order(ascending=False).index: 55 | x1 = df.xs(ix)[date1] 56 | x2 = df.xs(ix)[date2] 57 | y1 = counter 58 | 59 | col = color_dict[df.xs(ix)[color_col]] 60 | 61 | ax.scatter(x1, y1, color=col) 62 | ax.scatter(x2, y1, color=col) 63 | plt.fill_between([x1, x2], y1-.04, y1+.04, alpha=.8, color=col) 64 | counter += 1 65 | 66 | 67 | plt.yticks(np.arange(len(df)), df[yticks].values) 68 | 69 | return fig, ax 70 | 71 | 72 | -------------------------------------------------------------------------------- /epipy/rolling_proportion.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import pandas as pd 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | def rolling_proportion(df, date_col, value_col, value, window=30, dropna=True, label=False, fig=None, ax=None, color=None): 9 | """ 10 | Interpolated proportion of binary risk factor over time. 11 | 12 | df = pandas df 13 | date_col = name of column containing dates 14 | value_col = name of column to be tallied 15 | value = value to tally (e.g. 'Male') 16 | window = number of days to include. Default is 30. 17 | dropna = exclude rows where val is NaN. Default is true. False will include those rows. 18 | label = legend label 19 | fig, ax = matplotlib objects 20 | ----- 21 | Returns Series of proportions with date index, fig, and ax. 22 | ----- 23 | Example: 24 | datetime_df.index = df.dates 25 | rolling_proportion(datetime_df.sex, 'Male') 26 | 27 | Note: If you are having trouble, make ensure that your date_col is a datetime. 28 | """ 29 | 30 | df = df[df[date_col].isnull() == False] 31 | df.index = df[date_col] 32 | 33 | if dropna == False: 34 | df = df[value_col].fillna(False) 35 | else: 36 | df = df[df[value_col].isnull() == False] 37 | 38 | df['matches'] = df[value_col] == value 39 | df['matches'] = df['matches'].astype(np.int) 40 | df['ones'] = 1 41 | 42 | prop = pd.DataFrame() 43 | prop['numerator'] = df.matches.groupby(by=df.index).sum() 44 | prop['denom'] = df.ones.groupby(by=df.index).sum() 45 | prop['proportion'] = pd.rolling_sum(prop.numerator, window, 5)/pd.rolling_sum(prop.denom, window, 5) 46 | prop = prop.dropna(how='any') 47 | 48 | ts = pd.date_range(min(prop.index), max(prop.index)) 49 | new_prop = prop['proportion'] 50 | new_prop = new_prop.reindex(ts) 51 | new_prop = new_prop.fillna(method='pad') 52 | 53 | if fig is None and ax is None: 54 | fig, ax = plt.subplots() 55 | 56 | if color is None: 57 | color = 'b' 58 | 59 | ax.xaxis_date() 60 | new_prop.plot(ax=ax, label=label, color=color) 61 | fig.autofmt_xdate() 62 | ax.set_ylim(-0.05, 1.05) 63 | ax.set_xlabel('') 64 | if label != False: 65 | ax.legend() 66 | 67 | return new_prop, fig, ax 68 | -------------------------------------------------------------------------------- /epipy/data_generator.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import division 5 | import numpy as np 6 | import pandas as pd 7 | import string 8 | 9 | 10 | def _date_choice(ix_date, gen_time): 11 | date_rng = pd.date_range(ix_date, periods=gen_time*2, freq='D') 12 | date = np.random.choice(date_rng, 1) 13 | 14 | return date 15 | 16 | 17 | def generate_example_data(cluster_size, outbreak_len, clusters, gen_time): 18 | """ 19 | Generates example outbreak data 20 | 21 | PARAMETERS 22 | ------------------------------ 23 | cluster_size = mean number of cases in cluster. Build in sd of 2 24 | outbreak_len = duration of outbreak in days 25 | clusters = number of clusters to begenerated 26 | gen_time = time between cases in a cluster 27 | attribute = case attribute. Options are 'sex' (returns M, F) and 28 | 'health' (returns asymptomatic, alive, critical, dead) 29 | 30 | RETURNS 31 | ------------------------------ 32 | pandas dataframe with columns ['caseid', 'date', 'cluster', 'sex', 'health', 'exposure'] 33 | 34 | """ 35 | line_list = [] 36 | used = [] 37 | for i in range(clusters): 38 | cluster_letter = np.random.choice([i for i in string.ascii_uppercase if i not in used])[0] 39 | cluster_name = 'Cluster' + cluster_letter 40 | used.append(cluster_letter) 41 | 42 | ix_rng = pd.date_range('1/1/2014', periods=outbreak_len, freq='D') 43 | ix_date = np.random.choice(ix_rng, size=1) 44 | 45 | rng = int(np.random.normal(cluster_size, 1, 1)) 46 | if rng < 2: 47 | rng += 1 48 | 49 | for n in range(rng): 50 | date = _date_choice(ix_date[0], gen_time)[0] 51 | 52 | dates = [ix_date[0]] 53 | for n in range(rng): 54 | date = _date_choice(dates[-1], gen_time)[0] 55 | dates.append(date) 56 | 57 | attr1 = np.random.choice(['Male', 'Female'], size=1)[0] 58 | attr2 = np.random.choice(['asymptomatic', 'alive', 'critical', 'dead'], size=1)[0] 59 | attr3 = np.random.choice(['exposed', 'notexposed'], size=1)[0] 60 | 61 | line_list.append((len(line_list), date, cluster_name, attr1, attr2, attr3)) 62 | 63 | return pd.DataFrame(line_list, columns=['caseid', 'date', 'cluster', 'sex', 'health', 'exposure']) 64 | -------------------------------------------------------------------------------- /epipy/test/test_basics.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from datetime import datetime, timedelta 7 | import pytest 8 | import basics 9 | 10 | def _test_data(): 11 | _data = [(0, 'ClusterA', '2013-01-01', 'M'), 12 | (1, 'ClusterB', '2013-01-01', 'F'), 13 | (2, 'ClusterA', np.nan, 'M'), 14 | (3, 'ClusterC', '2013-01-04', 'F'), 15 | (4, 'ClusterB', '2013-01-03', 'M'), 16 | (5, 'ClusterB', '2013-01-05', 'M')] 17 | df = pd.DataFrame(_data, columns=['id', 'cluster', 'date', 'sex']) 18 | 19 | return df 20 | 21 | 22 | def test_date_convert_str(): 23 | df = _test_data() 24 | str_val = df.date[0] 25 | dtime = basics.date_convert(str_val) 26 | 27 | assert type(dtime) == datetime 28 | assert dtime == datetime.date(2013, 1, 1) 29 | 30 | 31 | def test_date_convert_nan(): 32 | df = _test_data() 33 | nan_val = df.date[2] 34 | dtime = basics.date_convert(nan_val) 35 | 36 | assert type(dtime) == float 37 | assert np.isnan(dtime) == True 38 | 39 | 40 | def test_date_convert_wrongformat(): 41 | wrong_val = '01-2012-01' 42 | 43 | with pytest.raises(ValueError): 44 | dtime = basics.date_convert(wrong_val) 45 | 46 | 47 | def test_date_convert_wrongformat2(): 48 | wrong_int = 1201201 49 | 50 | with pytest.raises(ValueError): 51 | dtime = basics.date_convert(wrong_int) 52 | 53 | 54 | def test_group_clusters(): 55 | df = _test_data() 56 | groups = basics.group_clusters(df, 'cluster', 'date') 57 | 58 | assert len(groups) == 3 59 | assert groups.groups == {'ClusterA': [0], 'ClusterB': [1, 4, 5], \ 60 | 'ClusterC': [3]} 61 | 62 | 63 | def test_cluster_to_tuple(): 64 | df = _test_data() 65 | df['datetime'] = df['date'].map(basics.date_convert) 66 | 67 | df_out = basics.cluster_builder(df, 'cluster', 'id', 'datetime', \ 68 | 'sex', 2, 1) 69 | df_out = df_out.sort('case_id') 70 | 71 | #sanity check 72 | assert df_out.ix[0]['case_id'] == 0 73 | assert df_out.ix[3]['case_id'] == 3 74 | #index nodes 75 | assert df_out.ix[0]['index_node'] == 0 76 | assert df_out.ix[4]['index_node'] == 1 77 | #source nodes 78 | assert df_out.ix[5]['source_node'] == 4 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /epipy/epicurve.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def epicurve_plot(df, date_col, freq, fig= None, ax=None, color=None): 9 | ''' 10 | Creates an epicurve (count of new cases over time) 11 | 12 | df = pandas dataframe 13 | date_col = date used to denote case onset or report date 14 | freq = desired plotting frequency. Can be day (d), month (m) or year (y) 15 | fig, ax = fig & ax objects 16 | color = value of bar color to be passed to matplotlib 17 | 18 | RETURNS 19 | ------------------ 20 | curve = Series of dates and counts 21 | fig, ax = matplotlib figure and axis objects 22 | ''' 23 | 24 | if ax == None: 25 | fig, ax = plt.subplots() 26 | 27 | df = df[df[date_col].isnull() == False] 28 | freq = freq.lower()[0] 29 | 30 | ix = pd.date_range(df[date_col].min(), df[date_col].max(), freq='d') 31 | 32 | curve = pd.DataFrame(df[date_col].value_counts().sort_index()) 33 | 34 | if freq == 'm' or freq == 'd': 35 | curve = curve.reindex(ix).fillna(0).resample(freq, how='sum', closed='right') 36 | if freq == 'y': 37 | curve = curve.groupby(curve.index.year).sum() 38 | 39 | curve = curve.rename(columns={date_col:'counts'}) 40 | 41 | fig, ax = _plot(curve, freq, fig, ax, color='#53B8DD') 42 | 43 | return curve, fig, ax 44 | 45 | 46 | def _plot(freq_table, freq, fig, ax, color): 47 | ''' 48 | Plot number of new cases over time 49 | freq_table = frequency table of cases by date, from epicurve() 50 | freq = inherited from epicurve 51 | ''' 52 | 53 | # care about date formatting 54 | if freq == 'd': 55 | ax.xaxis_date() 56 | fig.autofmt_xdate() 57 | ax.bar(freq_table.index.values, freq_table['counts'].values, align='center', color=color) 58 | 59 | elif freq == 'm': 60 | if len(freq_table) < 5: 61 | freq_table.index = freq_table.index.strftime('%b %Y') 62 | freq_table.plot(kind='bar', rot=0, legend=False, color=color) 63 | else: 64 | ax.xaxis_date() 65 | fig.autofmt_xdate() 66 | ax.bar(freq_table.index.values, freq_table['counts'].values, align='center',width=5, color=color) 67 | 68 | elif freq == 'y': 69 | freq_table.plot(kind='bar', rot=0, legend=False, color=color) 70 | 71 | return fig, ax 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /epipy/or_plot.py: -------------------------------------------------------------------------------- 1 | # usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | from .analyses import create_2x2, odds_ratio 7 | 8 | 9 | def _plot(_df, fig, ax): 10 | """ 11 | """ 12 | 13 | _df = pd.DataFrame(_df) 14 | df = _df.sort_values('ratio') 15 | df['color'] = 'grey' 16 | df.color[(df.lower > 1) & (df.upper > 1)] = 'blue' 17 | df.color[(df.lower < 1) & (df.upper < 1)] = 'red' 18 | 19 | df.index = range(len(df)) # reset the index to reflect order 20 | 21 | if fig is None and ax is None: 22 | fig, ax = plt.subplots(figsize=(8, 12)) 23 | 24 | ax.set_aspect('auto') 25 | ax.set_xlabel('Odds ratio') 26 | ax.grid(False) 27 | ax.set_ylim(-.5, len(df) - .5) 28 | plt.yticks(df.index) 29 | 30 | ax.scatter(df.ratio, df.index, c=df.color, s=50) 31 | for pos in range(len(df)): 32 | ax.fill_between([df.lower[pos], df.upper[pos]], pos-.01, pos+.01, color='grey', alpha=.3) 33 | 34 | ax.set_yticklabels(df.names) 35 | ax.vlines(x=1, ymin=-.5, ymax=len(df)-.5, colors='grey', linestyles='--') 36 | 37 | return fig, ax 38 | 39 | 40 | def or_plot(df, risk_cols, outcome_col, risk_order, outcome_order, fig=None, ax=None): 41 | """ 42 | df = pandas dataframe of line listing 43 | cols = list of columns to include in analysis 44 | risk_order: dictionary with risk_cols as keys, and a list of values as values, e.g. {'sex':['male', 'female']} 45 | outcome_order: list of values, e.g. ['alive', 'dead'] 46 | 47 | RETURNS 48 | -------- 49 | fig, ax = figure and axis objects 50 | """ 51 | 52 | ratio_df = [] 53 | 54 | for risk_col in risk_cols: 55 | # if risk_order != False: 56 | order = risk_order[risk_col] 57 | 58 | #elif risk_order == False: 59 | # risks = ["{}".format(val) for val in df[risk_col].dropna().unique()] 60 | # outcome_order = ["{}".format(val) for val in df[outcome_col].dropna().unique()] 61 | 62 | _df = df[[outcome_col, risk_col]].dropna(how='any') 63 | 64 | if len(_df[outcome_col].unique()) > 2: 65 | raise Exception('More than two unique values in the outcome') 66 | 67 | if len(_df[risk_col].unique()) > 2: 68 | raise Exception('More than two unique values in {}'.format(risk_col)) 69 | 70 | 71 | table = create_2x2(_df, risk_col, outcome_col, order, outcome_order) 72 | print('{}:'.format(risk_col)) 73 | ratio, or_ci =odds_ratio(table) 74 | print('\n') 75 | 76 | 77 | ratio_df.append({'names': risk_col, 'ratio':ratio, 'lower':or_ci[0], 'upper':or_ci[1]}) 78 | 79 | fig, ax = _plot(ratio_df, fig, ax) 80 | 81 | return fig, ax 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | EpiPy 2 | ======== 3 | A Python package for epidemiology. Epipy is a Python package for epidemiology. 4 | It contains tools for analyzing and visualizing epidemiology data. 5 | Epipy can currently produce: 6 | 7 | * stratified summary statistics 8 | * case tree and checkerboard plots 9 | * epicurves 10 | * analysis of case attribute (e.g. sex) by generation 11 | * 2x2 tables with odds ratio and relative risk 12 | * summary of cluster basic reproduction numbers 13 | 14 | Installation 15 | ------------ 16 | The github version of epipy is substantially further along that the pip version. I suggest installing from this repo when possible. 17 | 18 | Install using pip: 19 | 20 | pip install epipy 21 | 22 | Or clone the repository and install using setup.py: 23 | 24 | git clone https://github.com/cmrivers/epipy.git 25 | cd ./epipy 26 | pip install -r requirements.txt 27 | python setup.py install 28 | 29 | EpiPy is in development. Please feel free to contribute. 30 | Contact me at caitlin.rivers@gmail.com or [@cmyeaton](http://twitter.com/cmyeaton) with any questions. 31 | 32 | Learning Python 33 | ------------ 34 | New to Python? I teach a self-paced course on Python for epidemiologists over at [episkills.teachable.com](http://episkills.teachable.com). 35 | 36 | Contributing/Development 37 | ------------ 38 | If you want to contribute in this great project. First fork this repo in github. 39 | 40 | Clone your forked repo in your terminal using the appropriate command: 41 | 42 | git clone https://github.com/your-git-user-name/epipy.git 43 | cd ./epipy 44 | 45 | Add this repo as upstream remote: 46 | 47 | git remote add upstream git@github.com:cmrivers/epipy.git 48 | 49 | We use [gitflow](https://github.com/nvie/gitflow). Follow this [instructions](https://github.com/nvie/gitflow/wiki/Installation) to install. 50 | 51 | git branch master origin/master 52 | git flow init -d 53 | git flow feature start 54 | 55 | For install the tools for TDD use: 56 | 57 | pip install -r requirements.txt 58 | pip install -r requirements-tdd.txt 59 | 60 | To run the test suit use: 61 | 62 | cd ./epipy 63 | py.test test 64 | 65 | Then, do work and commit your changes. After finish your feature with coverage of test, please pull any change that ocurred from the upstream repo. You can use: 66 | 67 | git pull upstream master 68 | 69 | If git fast-forward error is issue then use: 70 | 71 | git rebase upstream/master 72 | 73 | Resolve the merge conflicts that couid exist using: 74 | 75 | git mergetool 76 | git rebase --continue 77 | 78 | After everything is ok then: 79 | 80 | git flow feature publish 81 | 82 | When done, open a pull request to your feature branch. 83 | 84 | 85 | Documentation 86 | ------------ 87 | The docs live at: [cmrivers.github.io/epipy](https://cmrivers.github.io/epipy) 88 | -------------------------------------------------------------------------------- /epipy/checkerboard.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | ------------- 5 | * Caitlin Rivers 6 | * [cmrivers@vbi.vt.edu](cmrivers@vbi.vt.edu) 7 | ------------- 8 | I developed checkerboard plots as a companion to case tree plots. A 9 | checkerboard plot shows when cases in a cluster occurred or were 10 | diagnosed, without assuming how they are related. 11 | ''' 12 | import epipy 13 | import matplotlib.pyplot as plt 14 | from datetime import timedelta 15 | from itertools import cycle 16 | import numpy as np 17 | 18 | def checkerboard_plot(df, case_id, cluster_id, date_col, labels='on', alpha=.8, palette=None): 19 | ''' 20 | PARAMETERS 21 | --------------------- 22 | df = pandas dataframe of line listing 23 | case_id = unique identifier of the cases 24 | cluster_id = identifier for each cluster, e.g. FamilyA 25 | date_col = column of onset or report dates 26 | labels = accepts 'on' or 'off'. Labels the first and last case in the cluster with 27 | the unique case identifier. 28 | alpha = transparency of block color 29 | palette = list of colors 30 | 31 | RETURNS 32 | --------------------- 33 | matplotlib figure and axis objects 34 | ''' 35 | clusters = epipy.group_clusters(df, cluster_id, date_col) 36 | 37 | fig, ax = plt.subplots(figsize=(12, 10)) 38 | ax.xaxis_date() 39 | ax.set_aspect('auto') 40 | axprop = ax.axis() 41 | fig.autofmt_xdate() 42 | 43 | grpnames = [key for key, group in clusters if len(group) > 1] 44 | plt.ylim(1, len(grpnames)) 45 | plt.yticks(np.arange(len(grpnames)), grpnames) 46 | 47 | xtog = timedelta(((4*axprop[1]-axprop[0])/axprop[1]), 0, 0) 48 | counter = 0 49 | if palette is None: 50 | cols = cycle([color for i, color in enumerate(plt.rcParams['axes.color_cycle'])]) 51 | else: 52 | cols = cycle(palette) 53 | 54 | for key, group in clusters: 55 | if len(group) > 1: 56 | color = next(cols) 57 | casenums = [int(num) for num in group.index] 58 | iter_casenums = cycle(casenums) 59 | 60 | positions = [] 61 | 62 | for casedate in group[date_col].order(): 63 | curr_casenum = next(iter_casenums) 64 | 65 | x1 = casedate 66 | x2 = casedate + xtog 67 | positions.append(x2) 68 | 69 | y1 = np.array([counter, counter]) 70 | y2 = y1 + 1 71 | 72 | plt.fill_between([x1, x2], y1, y2, color=color, alpha=alpha) 73 | ypos = y1[0] + .5 74 | 75 | if curr_casenum == min(casenums) or curr_casenum == max(casenums): 76 | textspot = x1 + timedelta((x2 - x1).days/2.0, 0, 0) 77 | plt.text(textspot, ypos, curr_casenum, horizontalalignment='center', 78 | verticalalignment='center') 79 | 80 | 81 | counter += 1 82 | 83 | return fig, ax -------------------------------------------------------------------------------- /epipy/basics.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | ------------- 5 | * Caitlin Rivers 6 | ------------- 7 | ''' 8 | 9 | import pandas as pd 10 | from datetime import datetime, timedelta 11 | import numpy as np 12 | import matplotlib as mpl 13 | 14 | def date_convert(date, str_format='%Y-%m-%d'): 15 | """ Convert dates to datetime object 16 | """ 17 | if type(date) == str: 18 | y = datetime.strptime(date, str_format) 19 | return y 20 | elif np.isnan(date) == True: 21 | y = np.nan 22 | return y 23 | else: 24 | raise ValueError('format of {} not recognized'.format(date)) 25 | 26 | 27 | def group_clusters(df, cluster_id, date_col): 28 | ''' Use pandas to group clusters by cluster identifier 29 | df = pandas dataframe 30 | cluster_id = column that identifies cluster membership, which can 31 | be a basic string like "hospital cluster A" 32 | date_col = onset or report date column 33 | ''' 34 | clusters = df[df[date_col].notnull()] 35 | groups = clusters.groupby(clusters[cluster_id]) 36 | 37 | return groups 38 | 39 | 40 | def cluster_builder(df, cluster_id, case_id, date_col, attr_col, gen_mean, gen_sd): 41 | ''' 42 | Given a line list with dates and info about cluster membership, 43 | this script will estimate the transmission tree of an infectious 44 | disease based on case onset dates. 45 | df = pandas dataframe of line list 46 | cluster_id = col that identifies cluster membership. Can be a 47 | basic string like "hospital cluster A" 48 | case_id = col with unique case identifier 49 | date_col = onset or report date column 50 | attr_col = column that will be used to color nodes based on 51 | attribute, e.g. case severity or gender 52 | gen_mean = generation time mean 53 | gen_sd = generation time standard deviation 54 | returns pandas groupby dataframe 55 | ''' 56 | clusters = group_clusters(df, cluster_id, date_col) 57 | gen_max = timedelta((gen_mean + gen_sd), 0) 58 | 59 | cluster_obj = [] 60 | for key, group in clusters: 61 | row = [tmp[1:4] for tmp in group[[case_id, date_col, 62 | attr_col]].sort(date_col, ).itertuples()] 63 | cluster_obj.append(row) 64 | 65 | network = [] 66 | for cluster in cluster_obj: 67 | #reverse dates, last case first 68 | cluster = np.array(cluster[::-1]) 69 | ids = cluster[:, 0] 70 | dates = cluster[:, 1] 71 | colors = cluster[:, 2] 72 | 73 | index_node = ids[-1] 74 | source_nodes = [] 75 | for i, (date, idx) in enumerate(zip(dates, ids)): 76 | start_date = date - gen_max 77 | start_node = ids[dates >= start_date][-1] 78 | 79 | if start_node == idx and idx != index_node: 80 | start_node = ids[i+1] 81 | 82 | source_nodes.append(start_node) 83 | 84 | for i in range(len(ids)): 85 | result = (ids[i], colors[i], index_node, source_nodes[i], dates[i]) 86 | network.append(result) 87 | 88 | df_out = pd.DataFrame(network, columns=['case_id', attr_col, 'index_node', 'source_node', 'time']) 89 | df_out.time = pd.to_datetime(df_out.time) 90 | 91 | df_out[['case_id', 'source_node', 'index_node']] = df_out[['case_id', 'source_node', 'index_node']].astype('int') 92 | df_out['pltdate'] = [mpl.dates.date2num(i) for i in df_out.time] 93 | df_out.index = df_out.case_id 94 | df_out = df_out.sort('pltdate') 95 | 96 | return df_out -------------------------------------------------------------------------------- /docs/examples_0.0.3.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | ------------- 5 | * Caitlin Rivers 6 | * [cmrivers@vbi.vt.edu](cmrivers@vbi.vt.edu) 7 | ------------- 8 | ''' 9 | import pandas as pd 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import epipy 13 | import os 14 | 15 | try: 16 | from mpltools import style, layout 17 | style.use('ggplot') 18 | layout.use('ggplot') 19 | except: 20 | pass 21 | 22 | 23 | ################################# 24 | # TEST DATA EXAMPLE # 25 | ################################# 26 | 27 | example_df = epipy.generate_example_data(cluster_size=6, outbreak_len=180, clusters=8, 28 | gen_time=5, attribute='health') 29 | fig, ax = epipy.case_tree_plot(example_df, cluster_id='Cluster', case_id='ID', date_col='Date', color='health', gen_mean=5, gen_sd=1 ) 30 | ax.set_title('Example outbreak data') 31 | ax.set_ylabel('Generations') 32 | fig.show() 33 | 34 | 35 | # Checkerboard plot 36 | fig, ax = epipy.checkerboard_plot(example_df, 'ID', 'Cluster', 'Date') 37 | ax.set_title("Example outbreak data") 38 | fig.show() 39 | 40 | ############################ 41 | ## MERS-CoV DATA EXAMPLE ### 42 | ############################ 43 | 44 | mers_df = epipy.get_data('mers_line_list') 45 | #you can also get synthetic data using epipy.get_data('example_data') 46 | 47 | # Data cleaning 48 | mers_df['onset_date'] = mers_df['Approx onset date'].map(epipy.date_convert) 49 | mers_df['report_date'] = mers_df['Approx reporting date'].map(epipy.date_convert) 50 | mers_df['dates'] = mers_df['onset_date'].combine_first(mers_df['report_date']) 51 | 52 | # Case tree plot 53 | fig, ax = epipy.case_tree_plot(mers_df, cluster_id='Cluster ID', \ 54 | case_id='Case #', date_col='dates', gen_mean = 5, \ 55 | gen_sd = 4, color='Health status') 56 | ax.set_title('Human clusters of MERS-CoV') 57 | fig.show() 58 | 59 | # Checkerboard plot 60 | fig, ax = epipy.checkerboard_plot(mers_df, 'Case #', 'Cluster ID', 'dates') 61 | ax.set_title("Human clusters of MERS-CoV") 62 | 63 | ################# 64 | ### EPICURVES ### 65 | ################# 66 | 67 | # Daily epicurve of MERS 68 | plt.figure() 69 | curve, fig, ax = epipy.epicurve_plot(mers_df, date_col='dates', freq='day') 70 | plt.title('Approximate onset or report date'); 71 | 72 | # Yearly epicurve of MERS 73 | plt.figure() 74 | epipy.epicurve_plot(mers_df, 'dates', freq='y') 75 | plt.title('Approximate onset or report date') 76 | 77 | # Monthly epicurve of MERS 78 | plt.figure() 79 | curve, fig, ax = epipy.epicurve_plot(mers_df, 'dates', freq='month') 80 | plt.title('Approximate onset or report date of MERS cases') 81 | fig.show() 82 | 83 | ################# 84 | ### ANALYSES #### 85 | ################# 86 | 87 | # We'll use the MERS data we worked with above 88 | # For this we'll need to build out the graph 89 | mers_G = epipy.build_graph(mers_df, cluster_id='Cluster ID', case_id='Case #', 90 | date_col='dates', color='Health status', gen_mean=5, gen_sd=4) 91 | 92 | # Analyze attribute by generation 93 | fig, ax, table = epipy.generation_analysis(mers_G, attribute='Health status', plot=True) 94 | fig.show() 95 | 96 | # Basic reproduction numbers 97 | R, fig, ax = epipy.reproduction_number(mers_G, index_cases=True, plot=True) 98 | print 'R0 median: {}'.format(R.median()) # the series object returned can be manipulated further 99 | fig.show() 100 | 101 | #2X2 table 102 | mers_df['condensed_health'] = mers_df['Health status'].replace(['Critical', 'Alive', 'Asymptomatic', 'Mild', 'Recovered', 'Reocvered'], 'Alive') 103 | table = epipy.create_2x2(mers_df, 'Sex', 'condensed_health', ['M', 'F'], ['Dead', 'Alive']) 104 | epipy.analyze_2x2(table) 105 | -------------------------------------------------------------------------------- /docs/examples_0.0.2.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | ------------- 5 | Examples for epipy versions 0.0.1 and 0.0.2X 6 | 7 | 8 | * Caitlin Rivers 9 | * [cmrivers@vbi.vt.edu](cmrivers@vbi.vt.edu) 10 | ------------- 11 | ''' 12 | import pandas as pd 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import epipy 16 | import os 17 | 18 | try: 19 | from mpltools import style, layout 20 | style.use('ggplot') 21 | layout.use('ggplot') 22 | except: 23 | pass 24 | 25 | 26 | ################################# 27 | # TEST DATA EXAMPLE # 28 | ################################# 29 | 30 | example_df = epipy.generate_example_data(cluster_size=6, outbreak_len=180, clusters=8, 31 | gen_time=5, attribute='health') 32 | fig, ax = epipy.case_tree_plot(example_df, cluster_id='Cluster', case_id='ID', date_col='Date', color='health', gen_mean=5, gen_sd=1 ) 33 | ax.set_title('Example outbreak data') 34 | ax.set_ylabel('Generations') 35 | fig.show() 36 | 37 | 38 | # Checkerboard plot 39 | fig, ax = epipy.checkerboard_plot(example_df, 'ID', 'Cluster', 'Date') 40 | ax.set_title("Example outbreak data") 41 | fig.show() 42 | 43 | ############################ 44 | ## MERS-CoV DATA EXAMPLE ### 45 | ############################ 46 | 47 | mers_df = epipy.get_data('mers_line_list') 48 | #you can also get synthetic data using epipy.get_data('example_data') 49 | 50 | # Data cleaning 51 | mers_df['onset_date'] = mers_df['Approx onset date'].map(epipy.date_convert) 52 | mers_df['report_date'] = mers_df['Approx reporting date'].map(epipy.date_convert) 53 | mers_df['dates'] = mers_df['onset_date'].combine_first(mers_df['report_date']) 54 | 55 | # Case tree plot 56 | fig, ax = epipy.case_tree_plot(mers_df, cluster_id='Cluster ID', \ 57 | case_id='Case #', date_col='dates', gen_mean = 5, \ 58 | gen_sd = 4, color='Health status') 59 | ax.set_title('Human clusters of MERS-CoV') 60 | fig.show() 61 | 62 | # Checkerboard plot 63 | fig, ax = epipy.checkerboard_plot(mers_df, 'Case #', 'Cluster ID', 'dates') 64 | ax.set_title("Human clusters of MERS-CoV") 65 | 66 | ################# 67 | ### EPICURVES ### 68 | ################# 69 | 70 | # Daily epicurve of MERS 71 | plt.figure() 72 | curve, fig, ax = epipy.epicurve_plot(mers_df, date_col='dates', freq='day') 73 | plt.title('Approximate onset or report date'); 74 | 75 | # Yearly epicurve of MERS 76 | plt.figure() 77 | epipy.epicurve_plot(mers_df, 'dates', freq='y') 78 | plt.title('Approximate onset or report date') 79 | 80 | # Monthly epicurve of MERS 81 | plt.figure() 82 | curve, fig, ax = epipy.epicurve_plot(mers_df, 'dates', freq='month') 83 | plt.title('Approximate onset or report date of MERS cases') 84 | fig.show() 85 | 86 | ################# 87 | ### ANALYSES #### 88 | ################# 89 | 90 | # We'll use the MERS data we worked with above 91 | # For this we'll need to build out the graph 92 | mers_G = epipy.build_graph(mers_df, cluster_id='Cluster ID', case_id='Case #', 93 | date_col='dates', color='Health status', gen_mean=5, gen_sd=4) 94 | 95 | # Analyze attribute by generation 96 | fig, ax, table = epipy.generation_analysis(mers_G, attribute='Health status', plot=True) 97 | fig.show() 98 | 99 | # Basic reproduction numbers 100 | R, fig, ax = epipy.reproduction_number(mers_G, index_cases=True, plot=True) 101 | print 'R0 median: {}'.format(R.median()) # the series object returned can be manipulated further 102 | fig.show() 103 | 104 | #2X2 table 105 | mers_df['condensed_health'] = mers_df['Health status'].replace(['Critical', 'Alive', 'Asymptomatic', 'Mild', 'Recovered', 'Reocvered'], 'Alive') 106 | table = epipy.create_2x2(mers_df, 'Sex', 'condensed_health', ['M', 'F'], ['Dead', 'Alive']) 107 | epipy.analyze_2x2(table) 108 | -------------------------------------------------------------------------------- /epipy/test/test_analyses.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import division 5 | import numpy as np 6 | import pandas as pd 7 | import networkx as nx 8 | import pytest 9 | import analyses 10 | 11 | def test_ordered_table_list(): 12 | table = [(0, 1), 13 | (2, 3)] 14 | 15 | a, b, c, d = analyses._ordered_table(table) 16 | assert a == 0 17 | assert b == 1 18 | assert c == 2 19 | assert d == 3 20 | 21 | 22 | def test_ordered_table_numpy(): 23 | table = [(0, 1), 24 | (2, 3)] 25 | table = np.array(table) 26 | 27 | a, b, c, d = analyses._ordered_table(table) 28 | assert a == 0 29 | assert b == 1 30 | assert c == 2 31 | assert d == 3 32 | 33 | 34 | def test_ordered_table_DataFrame(): 35 | table = [(0, 1), 36 | (2, 3)] 37 | table = pd.DataFrame(table) 38 | 39 | a, b, c, d = analyses._ordered_table(table) 40 | assert a == 0 41 | assert b == 1 42 | assert c == 2 43 | assert d == 3 44 | 45 | 46 | def test_ordered_table_typeError(): 47 | table = [(0, 1), 48 | (2, 3)] 49 | table = np.matrix(table) 50 | 51 | with pytest.raises(TypeError): 52 | a, b, c, d = analyses._ordered_table(table) 53 | 54 | 55 | 56 | def test_odds_ratio(): 57 | table = [(1, 2), 58 | (3, 4)] 59 | 60 | ratio, or_ci = analyses.odds_ratio(table) 61 | 62 | assert np.allclose(ratio, .6667, atol=.01) 63 | assert np.allclose(or_ci, (0.03939, 11.28), atol=.01) 64 | 65 | 66 | def test_relative_risk(): 67 | table = [(1, 2), 68 | (3, 4)] 69 | 70 | rr, rr_ci = analyses.relative_risk(table) 71 | 72 | assert np.allclose(rr, 0.7778, atol=.01) 73 | assert np.allclose(rr_ci, (0.1267, 4.774), atol=.01) 74 | 75 | 76 | def test_chi2(): 77 | table = [(1, 2), 78 | (3, 4)] 79 | 80 | chi2, p, dof, expected = analyses.chi2(table) 81 | 82 | assert np.allclose(chi2, 0.1786, atol=.01) 83 | 84 | 85 | def test_AR(): 86 | table = [(1, 2), 87 | (3, 4)] 88 | 89 | ar, arp, par, parp = analyses.attributable_risk(table) 90 | 91 | assert np.allclose(ar, -.09524, atol=.01) 92 | assert np.allclose(arp, -28.5714, atol=.01) 93 | assert np.allclose(par, -.02857, atol=.01) 94 | assert np.allclose(parp, -7.143, atol=.01) 95 | 96 | 97 | def test_create2x2(): 98 | df = pd.DataFrame({'Exposed':['Y', 'Y', 'N', 'Y'], \ 99 | 'Sick':['Y', 'N', 'N', 'Y']}) 100 | table = analyses.create_2x2(df, 'Exposed', 'Sick', ['Y', 'N'], \ 101 | ['Y', 'N']) 102 | 103 | assert table.ix[0][0] == 2 104 | assert table.ix[0][1] == 1 105 | assert table.ix[1][0] == 0 106 | assert table.ix[1][1] == 1 107 | 108 | 109 | def test_2x2_errorRaises(): 110 | df = pd.DataFrame({'Exposed':['Y', 'Y', 'N', 'Y'], \ 111 | 'Sick':['Y', 'N', 'N', 'Y']}) 112 | 113 | with pytest.raises(TypeError): 114 | table = analyses.create_2x2(df, 'Exposed', 'Sick', ['Y', 'N'], \ 115 | 'Y') 116 | 117 | with pytest.raises(AssertionError): 118 | table = analyses.create_2x2(df, 'Exposed', 'Sick', ['Y', 'N'], \ 119 | ['Y']) 120 | 121 | def _create_graph(): 122 | G = nx.DiGraph() 123 | G.add_nodes_from([3, 4, 5]) 124 | G.node[3]['generation'] = 0 125 | G.node[4]['generation'] = 1 126 | G.node[5]['generation'] = 1 127 | G.node[3]['health'] = 'alive' 128 | G.node[4]['health'] = 'dead' 129 | G.node[5]['health'] = 'alive' 130 | G.add_edges_from([(3, 4), (3, 5)]) 131 | 132 | return G 133 | 134 | 135 | def test_generation_analysis(): 136 | G = _create_graph() 137 | table = analyses.generation_analysis(G, 'health', plot=False) 138 | 139 | assert table.ix[0][0] == 1 140 | assert table.ix[0][1] == 0 141 | assert table.ix[1][0] == 1 142 | assert table.ix[1][1] == 1 143 | 144 | 145 | def test_reproduction_number_index(): 146 | G = _create_graph() 147 | R = analyses.reproduction_number(G, index_cases=True, plot=False) 148 | 149 | assert len(R) == 3 150 | assert R.iget(0) == 2 151 | assert R.iget(1) == 0 152 | assert R.iget(2) == 0 153 | 154 | 155 | def test_reproduction_number_noindex(): 156 | G = _create_graph() 157 | R = analyses.reproduction_number(G, index_cases=False, plot=False) 158 | 159 | assert len(R) == 2 160 | assert R.iget(0) == 0 161 | assert R.iget(1) == 0 162 | 163 | 164 | def test_numeric_summary(): 165 | df = pd.DataFrame({'Age' : [10, 12, 14], 'Group' : ['A', 'B', 'B'] }) 166 | summ = analyses.summary(df.Age) 167 | 168 | assert summ['count'] == 3 169 | assert summ['missing'] == 0 170 | assert summ['min'] == 10 171 | assert summ['median'] == 12 172 | assert summ['mean'] == 12 173 | assert summ['std'] == 2 174 | assert summ['max'] == 14 175 | 176 | 177 | def test_categorical_summary(): 178 | df = pd.DataFrame({'Age' : [10, 12, 14], 'Group' : ['A', 'B', 'B'] }) 179 | summ = analyses.summary(df.Group) 180 | 181 | assert summ.ix[0]['count'] == 2 182 | assert np.allclose(summ.ix[0]['freq'], 2/3, atol=.01) 183 | 184 | 185 | def test_grouped_summary(): 186 | df = pd.DataFrame({'Age' : [10, 12, 14], 'Group' : ['A', 'B', 'B'] }) 187 | summ = analyses.summary(df.Age, df.Group) 188 | 189 | assert len(summ) == 2 190 | assert len(summ.columns) == 7 191 | 192 | 193 | 194 | 195 | -------------------------------------------------------------------------------- /epipy/case_tree.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import division 5 | import numpy as np 6 | from .basics import cluster_builder 7 | import matplotlib.pyplot as plt 8 | import networkx as nx 9 | 10 | def build_graph(df, cluster_id, case_id, date_col, color, gen_mean, gen_sd, palette=None): 11 | """ 12 | Generate a directed graph from data on transmission tree. 13 | Node color is determined by node attributes, e.g. case severity or gender. 14 | df = pandas dataframe 15 | """ 16 | 17 | clusters = cluster_builder(df=df, cluster_id=cluster_id, \ 18 | case_id=case_id, date_col=date_col, attr_col=color, \ 19 | gen_mean=gen_mean, gen_sd=gen_sd) 20 | 21 | G = nx.DiGraph() 22 | G.add_nodes_from(clusters['case_id']) 23 | 24 | edgelist = [pair for pair in clusters[['source_node']].dropna().itertuples()] 25 | G.add_edges_from(edgelist) 26 | nx.set_node_attributes(G, 'date', clusters['time'].to_dict()) 27 | nx.set_node_attributes(G, 'pltdate', clusters['pltdate'].to_dict()) 28 | nx.set_node_attributes(G, 'source_node', clusters['source_node'].to_dict()) 29 | nx.set_node_attributes(G, color, clusters[color].to_dict()) 30 | nx.set_node_attributes(G, 'index_node', clusters['index_node'].to_dict()) 31 | G = nx.DiGraph.reverse(G) 32 | 33 | for i in G.nodes(): 34 | G.node[i]['generation'] = _generations(G, i) 35 | 36 | return G 37 | 38 | 39 | 40 | def plot_tree(G, color, node_size, loc, legend, color_palette, fig, ax): 41 | 42 | if ax is None or fig is None: 43 | fig, ax = plt.subplots(figsize=(12, 8)) 44 | 45 | fig.autofmt_xdate() 46 | ax.xaxis_date() 47 | ax.set_aspect('auto') 48 | 49 | coords = _layout(G) 50 | plt.ylim(ymin=-.05, ymax=max([val[1] for val in coords.itervalues()])+1) 51 | 52 | colormap, color_floats = _colors(G, color, color_palette=color_palette) 53 | 54 | if legend == True: 55 | x_val = G.nodes()[0] 56 | lines = [] 57 | 58 | for key, value in colormap.iteritems(): 59 | plt.scatter(G.node[x_val]['pltdate'], value[0], color=value, alpha=0) 60 | line = plt.Line2D(range(1), range(1), color=value, marker='o', markersize=6, alpha=.8, label=key) 61 | lines.append(line) 62 | 63 | ax.legend(lines, [k for k in colormap.iterkeys()], loc=loc) 64 | 65 | nx.draw_networkx(G, ax=ax, with_labels=False, pos=coords, node_color=color_floats, 66 | node_size=node_size, alpha=.8) 67 | 68 | return fig, ax 69 | 70 | 71 | def case_tree_plot(df, cluster_id, case_id, date_col, color, \ 72 | gen_mean, gen_sd, node_size=100, loc='best',\ 73 | legend=True, color_palette=None, fig=None, ax=None): 74 | """ 75 | Plot casetree 76 | df = pandas dataframe, line listing 77 | cluster_id = col that identifies cluster membership. Can be a 78 | basic string like "hospital cluster A" 79 | case_id = col with unique case identifier 80 | date_col = onset or report date column 81 | color = column that will be used to color nodes based on 82 | attribute, e.g. case severity or gender 83 | gen_mean = generation time mean 84 | gen_sd = generation time standard deviation 85 | node_size = on (display node) or off (display edge only). Default is on. 86 | loc = legend location. See matplotlib args. 87 | legend = True for legend, False for no legend 88 | color_palette = dictionary of category:color pairs OR a list of colors 89 | fig, ax = fig ax objects 90 | """ 91 | G = build_graph(df, cluster_id, case_id, date_col, color, \ 92 | gen_mean, gen_sd) 93 | 94 | fig, ax = plot_tree(G, color, node_size, loc, legend, color_palette, fig, ax) 95 | 96 | return G, fig, ax 97 | 98 | 99 | def _colors(G, color, color_palette=None): 100 | """ 101 | Determines colors of the node based on node attribute, 102 | e.g. case severity or gender. 103 | G = networkx object 104 | color = name of node attribute in graph used to assign color 105 | """ 106 | # collect list of unique attributes from graph 107 | if type(color_palette) != dict: 108 | categories = [] 109 | for node in G.nodes(): 110 | categories.append(G.node[node][color]) 111 | 112 | categories = np.unique(categories) 113 | # create color map of attributes and colors 114 | if type(color_palette) == list: 115 | colors = color_palette 116 | else: 117 | colors = [color for i, c in enumerate(plt.rcParams['axes.color_cycle'])] 118 | 119 | color_dict = dict(zip(categories, colors)) 120 | 121 | elif type(color_palette) == dict: 122 | color_dict = color_palette 123 | 124 | color_floats = [] 125 | for node in G.nodes(): 126 | G.node[node]['plot_color'] = color_dict[G.node[node][color]] 127 | color_floats.append(color_dict[G.node[node][color]]) 128 | 129 | 130 | return color_dict, color_floats 131 | 132 | 133 | def _generations(G, node): 134 | """ Recursively determines the generation of the node, e.g. how many 135 | links up the chain of transmission it is. 136 | This value is used as the y coordinate. 137 | G = networkx object 138 | node = node in network 139 | """ 140 | levels = 0 141 | 142 | while node != G.node[node]['source_node']: 143 | node = G.node[node]['source_node'] 144 | levels += 1 145 | 146 | return levels 147 | 148 | 149 | def _layout(G): 150 | """Determine x and y coordinates of each node. 151 | G = networkx object 152 | axprop = matplotlib axis object 153 | """ 154 | np.random.seed(0) # consistent layout between runs(?) 155 | positions = [] 156 | 157 | for i in G.nodes(): 158 | xcord = G.node[i]['pltdate'] 159 | generation = G.node[i]['generation'] 160 | if generation == 0: 161 | ygen = generation 162 | else: 163 | jittery = np.random.uniform(-.2, .2, 1) 164 | ygen = generation + jittery 165 | 166 | positions.append([xcord, ygen]) 167 | 168 | return dict(zip(G, np.array(positions))) 169 | 170 | 171 | -------------------------------------------------------------------------------- /epipy/analyses.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import division 5 | import numpy as np 6 | from scipy.stats import chi2_contingency 7 | import pandas as pd 8 | import matplotlib.pyplot as plt 9 | 10 | """ 11 | Author: Caitlin Rivers 12 | Analysis functions for package epipy. 13 | """ 14 | def _get_table_labels(table): 15 | """ 16 | Returns classic a, b, c, d labels for contingency table calcs. 17 | """ 18 | a = table[0][0] 19 | b = table[0][1] 20 | c = table[1][0] 21 | d = table[1][1] 22 | 23 | return a, b, c, d 24 | 25 | 26 | def _ordered_table(table): 27 | """ 28 | Determine type of table input. Find classic a, b, c, d labels 29 | for contigency table calculations. 30 | """ 31 | if type(table) is list: 32 | a, b, c, d = _get_table_labels(table) 33 | elif type(table) is pd.core.frame.DataFrame: 34 | a, b, c, d = _get_table_labels(table.values) 35 | elif type(table) is np.ndarray: 36 | a, b, c, d = _get_table_labels(table) 37 | else: 38 | raise TypeError('table format not recognized') 39 | 40 | return a, b, c, d 41 | 42 | 43 | def _conf_interval(ratio, std_error): 44 | """ 45 | Calculate 95% confidence interval for odds ratio and relative risk. 46 | """ 47 | 48 | _lci = np.log(ratio) - 1.96*std_error 49 | _uci = np.log(ratio) + 1.96*std_error 50 | 51 | lci = round(np.exp(_lci), 2) 52 | uci = round(np.exp(_uci), 2) 53 | 54 | return (lci, uci) 55 | 56 | def _numeric_summary(column): 57 | """ 58 | Finds count, number of missing values, min, median, mean, std, and 59 | max. 60 | See summary() 61 | """ 62 | names = ['count', 'missing', 'min', 'median', 'mean', 'std', 'max'] 63 | _count = len(column) 64 | _miss = _count - len(column.dropna()) 65 | _min = column.min() 66 | _median = column.median() 67 | _mean = column.mean() 68 | _std = column.std() 69 | _max = column.max() 70 | summ = pd.Series([_count, _miss, _min, _median, _mean, _std, _max], index=names) 71 | 72 | return summ 73 | 74 | 75 | def _categorical_summary(column, n=None): 76 | """ 77 | Finds count and frequency of each unique value in the column. 78 | See summary(). 79 | """ 80 | if n is not None: 81 | _count = column.value_counts()[:n] 82 | else: 83 | _count = column.value_counts() 84 | names = ['count', 'freq'] 85 | _freq = column.value_counts(normalize=True)[:n] 86 | summ = pd.DataFrame([_count, _freq], index=names).T 87 | 88 | return summ 89 | 90 | 91 | def _summary_calc(column, by=None): 92 | """ 93 | Calculates approporiate summary statistics based on data type. 94 | PARAMETERS 95 | ---------------------- 96 | column = one column (series) of pandas df 97 | by = optional. stratifies summary statistics by each value in the 98 | column. 99 | 100 | RETURNS 101 | ---------------------- 102 | if column data type is numeric, returns summary statistics 103 | if column data type is an object, returns count and frequency of 104 | top 5 most common values 105 | """ 106 | if column.dtype == 'float64' or column.dtype == 'int64': 107 | coltype = 'numeric' 108 | elif column.dtype == 'object': 109 | coltype = 'object' 110 | 111 | 112 | if by is None: 113 | if coltype == 'numeric': 114 | summ = _numeric_summary(column) 115 | 116 | elif coltype == 'object': 117 | summ = _categorical_summary(column, 5) 118 | 119 | else: 120 | 121 | if coltype == 'numeric': 122 | column_list = [] 123 | 124 | vals = by.dropna().unique() 125 | for value in vals: 126 | subcol = column[by == value] 127 | summcol = _numeric_summary(subcol) 128 | column_list.append(summcol) 129 | 130 | summ = pd.DataFrame(column_list, index=vals) 131 | 132 | elif coltype == 'object': 133 | subcol = column.groupby(by) 134 | summ = _categorical_summary(subcol) 135 | #summ = _summ.sort_values(by=subcol) 136 | 137 | return summ 138 | 139 | 140 | def reproduction_number(G, index_cases=True, plot=True): 141 | """ 142 | Finds each case's basic reproduction number, which is the number of secondary 143 | infections each case produces. 144 | 145 | PARAMETERS 146 | ---------------- 147 | G = networkx object 148 | index_cases = include index nodes, i.e. those at generation 0. Default is True. 149 | Excluding them is useful if you want to calculate the human to human 150 | reproduction number without considering zoonotically acquired cases. 151 | summary = print summary statistics of the case reproduction numbers 152 | plot = create histogram of case reproduction number distribution. 153 | 154 | RETURNS 155 | ---------------- 156 | pandas series of case reproduction numbers and matplotlib figure 157 | and axis objects if plot=True 158 | """ 159 | 160 | if index_cases == True: 161 | R = pd.Series(G.out_degree()) 162 | 163 | elif index_cases == False: 164 | degrees = {} 165 | 166 | for n in G.node: 167 | if G.node[n]['generation'] > 0: 168 | degrees[n] = G.out_degree(n) 169 | R = pd.Series(degrees) 170 | 171 | print('Summary of reproduction numbers') 172 | print(R.describe(), '\n') 173 | 174 | if plot == True: 175 | fig, ax = plt.subplots() 176 | R.hist(ax=ax, alpha=.5) 177 | ax.set_xlabel('Secondary cases') 178 | ax.set_ylabel('Count') 179 | ax.grid(False) 180 | return R, fig, ax 181 | 182 | else: 183 | return R 184 | 185 | 186 | def generation_analysis(G, attribute, plot=True): 187 | """ 188 | Analyzes an attribute, e.g. health status, by generation. 189 | 190 | PARAMETERS 191 | ------------- 192 | G = networkx object 193 | attribute = case attribute for analysis, e.g. health status or sex 194 | table = print cross table of attribute by generation. Default is true. 195 | plot = produce histogram of attribute by generation. Default is true. 196 | 197 | RETURNS 198 | -------------- 199 | matplotlib figure and axis objects 200 | 201 | """ 202 | 203 | gen_df = pd.DataFrame(G.node).T 204 | 205 | print('{} by generation').format(attribute) 206 | table = pd.crosstab(gen_df.generation, gen_df[attribute], margins=True) 207 | print(table, '\n') 208 | 209 | if plot == True: 210 | fig, ax = plt.subplots() 211 | ax.set_aspect('auto') 212 | pd.crosstab(gen_df.generation, gen_df[attribute]).plot(kind='bar', ax=ax, alpha=.5, rot=0) 213 | ax.set_xlabel('Generation') 214 | ax.set_ylabel('Case count') 215 | ax.grid(False) 216 | ax.legend(loc='best'); 217 | return fig, ax, table 218 | else: 219 | return table 220 | 221 | 222 | def create_2x2(df, row, column, row_order, col_order): 223 | """ 224 | 2x2 table of disease and exposure in traditional epi order. 225 | 226 | Table format: 227 | Disease 228 | Exposure YES NO 229 | YES a b 230 | NO c d 231 | 232 | PARAMETERS 233 | ----------------------- 234 | df = pandas dataframe of line listing 235 | row = name of exposure row as string 236 | column = name of outcome column as string 237 | row_order = list of length 2 of row values in yes/no order. 238 | Example: ['Exposed', 'Unexposed'] 239 | col_order = list of length 2 column values in yes/no order. 240 | Example: ['Sick', 'Not sick'] 241 | 242 | RETURNS 243 | ------------------------ 244 | pandas dataframe of 2x2 table. Prints odds ratio and relative risk. 245 | """ 246 | if type(col_order) != list or type(row_order) != list: 247 | raise TypeError('row_order and col_order must each be lists of length 2') 248 | 249 | if len(col_order) != 2 or len(row_order) != 2: 250 | raise AssertionError('row_order and col_order must each be lists of length 2') 251 | 252 | _table = pd.crosstab(df[row], df[column], margins=True).to_dict() 253 | 254 | trow = row_order[0] 255 | brow = row_order[1] 256 | tcol = col_order[0] 257 | bcol = col_order[1] 258 | 259 | table = pd.DataFrame(_table, index=[trow, brow, 'All'], columns=[tcol, bcol, 'All']) 260 | 261 | return table 262 | 263 | 264 | def analyze_2x2(table): 265 | """ 266 | Prints odds ratio, relative risk, and chi square. 267 | See also create_2x2(), odds_ratio(), relative_risk(), and chi2() 268 | 269 | PARAMETERS 270 | -------------------- 271 | 2x2 table as pandas dataframe, numpy array, or list in format [a, b, c, d] 272 | 273 | Table format: 274 | Disease 275 | Exposure YES NO 276 | YES a b 277 | NO c d 278 | 279 | """ 280 | 281 | odds_ratio(table) 282 | relative_risk(table) 283 | attributable_risk(table) 284 | chi2(table) 285 | 286 | 287 | def odds_ratio(table): 288 | """ 289 | Calculates the odds ratio and 95% confidence interval. See also 290 | analyze_2x2() 291 | *Cells in the table with a value of 0 will be replaced with .1 292 | 293 | PARAMETERS 294 | ---------------------- 295 | table = accepts pandas dataframe, numpy array, or list in [a, b, c, d] format. 296 | 297 | RETURNS 298 | ---------------------- 299 | returns and prints odds ratio and tuple of 95% confidence interval 300 | """ 301 | 302 | a, b, c, d = _ordered_table(table) 303 | 304 | ratio = (a*d)/(b*c) 305 | or_se = np.sqrt((1/a)+(1/b)+(1/c)+(1/d)) 306 | or_ci = _conf_interval(ratio, or_se) 307 | print('Odds ratio: {} (95% CI: {})'.format(round(ratio, 2), or_ci)) 308 | 309 | return round(ratio, 2), or_ci 310 | 311 | 312 | 313 | 314 | def relative_risk(table, display=True): 315 | """ 316 | Calculates the relative risk and 95% confidence interval. See also 317 | analyze_2x2(). 318 | *Cells in the table with a value of 0 will be replaced with .1 319 | 320 | PARAMETERS 321 | ---------------------- 322 | table = accepts pandas dataframe, numpy array, or list in [a, b, c, d] format. 323 | 324 | RETURNS 325 | ---------------------- 326 | returns and prints relative risk and tuple of 95% confidence interval 327 | """ 328 | 329 | a, b, c, d = _ordered_table(table) 330 | 331 | rr = (a/(a+b))/(c/(c+d)) 332 | rr_se = np.sqrt(((1/a)+(1/c)) - ((1/(a+b)) + (1/(c+d)))) 333 | rr_ci = _conf_interval(rr, rr_se) 334 | 335 | if display is not False: 336 | print('Relative risk: {} (95% CI: {}-{})\n'.format(round(rr, 2), round(rr_ci[0],2), round(rr_ci[1], 2))) 337 | 338 | return rr, rr_ci 339 | 340 | 341 | def attributable_risk(table): 342 | """ 343 | Calculate the attributable risk, attributable risk percent, 344 | and population attributable risk. 345 | 346 | PARAMETERS 347 | ---------------- 348 | table = 2x2 table. See 2x2_table() 349 | 350 | RETURNS 351 | ---------------- 352 | prints and returns attributable risk (AR), attributable risk percent 353 | (ARP), population attributable risk (PAR) and population attributable 354 | risk percent (PARP). 355 | """ 356 | a, b, c, d = _ordered_table(table) 357 | N = a + b + c + d 358 | 359 | ar = (a/(a+b))-(c/(c+d)) 360 | ar_se = np.sqrt(((a+c)/N)*(1-((a+c)/N))*((1/(a+b))+(1/(c+d)))) 361 | ar_ci = (round(ar-(1.96*ar_se), 2), round(ar+(1.96*ar_se), 2)) 362 | 363 | rr, rci = relative_risk(table, display=False) 364 | arp = 100*((rr-1)/(rr)) 365 | arp_se = (1.96*ar_se)/ar 366 | arp_ci = (round(arp-arp_se, 2), round(arp+arp_se, 3)) 367 | 368 | par = ((a+c)/N) - (c/(c+d)) 369 | parp = 100*(par/(((a+c)/N))) 370 | 371 | print('Attributable risk: {} (95% CI: {})'.format(round(ar, 3), ar_ci)) 372 | print('Attributable risk percent: {}% (95% CI: {})'.format(round(arp, 2), arp_ci)) 373 | print('Population attributable risk: {}'.format(round(par, 3))) 374 | print('Population attributable risk percent: {}% \n'.format(round(parp, 2))) 375 | 376 | return ar, arp, par, parp 377 | 378 | 379 | def chi2(table): 380 | """ 381 | Scipy.stats function to calculate chi square. 382 | PARAMETERS 383 | ---------------------- 384 | table = accepts pandas dataframe or numpy array. See also 385 | analyze_2x2(). 386 | 387 | RETURNS 388 | ---------------------- 389 | returns chi square with yates correction, p value, 390 | degrees of freedom, and array of expected values. 391 | prints chi square and p value 392 | """ 393 | chi2, p, dof, expected = chi2_contingency(table) 394 | print('Chi square: {}'.format(chi2)) 395 | print('p value: {}'.format(p)) 396 | 397 | return chi2, p, dof, expected 398 | 399 | 400 | def summary(data, by=None): 401 | """ 402 | Displays approporiate summary statistics for each column in a line listing. 403 | 404 | PARAMETERS 405 | ---------------------- 406 | data = pandas data frame or series 407 | 408 | RETURNS 409 | ---------------------- 410 | for each column in the dataframe, or for hte series: 411 | - if column data type is numeric, returns summary statistics 412 | - if column data type is non-numeric, returns count and frequency of 413 | top 5 most common values. 414 | 415 | EXAMPLE 416 | ---------------------- 417 | df = pd.DataFrame({'Age' : [10, 12, 14], 'Group' : ['A', 'B', 'B'] }) 418 | 419 | In: summary(df.Age) 420 | Out: 421 | count 3 422 | missing 0 423 | min 10 424 | median 12 425 | mean 12 426 | std 2 427 | max 14 428 | dtype: float64 429 | 430 | In: summary(df.Group) 431 | Out: 432 | count freq 433 | B 2 0.666667 434 | A 1 0.333333 435 | 436 | In:summary(df.Age, by=df.Group) 437 | Out count missing min median mean std max 438 | A 1 0 10 10 10 NaN 10 439 | B 2 0 12 13 13 1.414214 14 440 | """ 441 | if type(data) == pd.core.series.Series: 442 | summ = _summary_calc(data, by=by) 443 | return summ 444 | 445 | elif type(data) == pd.core.frame.DataFrame: 446 | for column in data: 447 | summ = _summary_calc(data[column], by=None) 448 | print('----------------------------------') 449 | print(column, '\n') 450 | print(summ) 451 | 452 | 453 | def diagnostic_accuracy(table, display=True): 454 | """ 455 | Calculates the sensitivity, specificity, negative and positive predictive values 456 | of a 2x2 table with 95% confidence intervals. Note that confidence intervals 457 | are made based on a normal approximation, and may not be appropriate for 458 | small sample sizes. 459 | 460 | PARAMETERS 461 | ---------------------- 462 | table = accepts pandas dataframe, numpy array, or list in [a, b, c, d] format. 463 | 464 | RETURNS 465 | ---------------------- 466 | returns and prints diagnostic accuracy estimates and tuple of 95% confidence interval 467 | 468 | Author: Eric Lofgren 469 | """ 470 | a, b, c, d = _ordered_table(table) 471 | 472 | sen = (a/(a+c)) 473 | sen_se = np.sqrt((sen*(1-sen))/(a+c)) 474 | sen_ci = (sen-(1.96*sen_se),sen+(1.96*sen_se)) 475 | spec = (d/(b+d)) 476 | spec_se = np.sqrt((spec*(1-spec))/(b+d)) 477 | spec_ci = (spec-(1.96*spec_se),spec+(1.96*spec_se)) 478 | PPV = (a/(a+b)) 479 | PPV_se = np.sqrt((PPV*(1-PPV))/(a+b)) 480 | PPV_ci = (PPV-(1.96*PPV_se),PPV+(1.96*PPV_se)) 481 | NPV = (d/(c+d)) 482 | NPV_se = np.sqrt((NPV*(1-NPV))/(c+d)) 483 | NPV_ci = (NPV-(1.96*NPV_se),NPV+(1.96*NPV_se)) 484 | 485 | if display is not False: 486 | print('Sensitivity: {} (95% CI: {})\n'.format(round(sen, 2), sen_ci)) 487 | print('Specificity: {} (95% CI: {})\n'.format(round(spec, 2), spec_ci)) 488 | print('Positive Predictive Value: {} (95% CI: {})\n'.format(round(PPV, 2), PPV_ci)) 489 | print('Negative Predictive Value: {} (95% CI: {})\n'.format(round(NPV, 2), NPV_ci)) 490 | 491 | return sen,sen_ci,spec,spec_ci,PPV,PPV_ci,NPV,NPV_ci 492 | 493 | 494 | def kappa_agreement(table, display=True): 495 | """ 496 | Calculated an unweighted Cohen's kappa statistic of observer agreement for a 2x2 table. 497 | Note that the kappa statistic can be extended to an n x m table, but this 498 | implementation is restricted to 2x2. 499 | 500 | PARAMETERS 501 | ---------------------- 502 | table = accepts pandas dataframe, numpy array, or list in [a, b, c, d] format. 503 | 504 | RETURNS 505 | ---------------------- 506 | returns and prints the Kappa statistic 507 | 508 | Author: Eric Lofgren 509 | """ 510 | a, b, c, d = _ordered_table(table) 511 | n = a + b + c + d 512 | pr_a = ((a+d)/n) 513 | pr_e = (((a+b)/n) * ((a+c)/n)) + (((c+d)/n) * ((b+d)/n)) 514 | k = (pr_a - pr_e)/(1 - pr_e) 515 | if display is not False: 516 | print("Cohen's Kappa: {}\n").format(round(k, 2)) 517 | 518 | return k 519 | --------------------------------------------------------------------------------