├── Procfile ├── .gitignore ├── settings ├── about.py └── config.py ├── README.md ├── application ├── static │ ├── logo.PNG │ └── favicon.ico └── dash.py ├── run.py ├── requirements.txt └── python ├── data.py ├── model.py └── result.py /Procfile: -------------------------------------------------------------------------------- 1 | web gunicorn run:app --preload --workers 1 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .spyproject/ 3 | *.DS_Store 4 | 5 | -------------------------------------------------------------------------------- /settings/about.py: -------------------------------------------------------------------------------- 1 | 2 | txt = "Select a country and see the forecast" -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dash App for Covid-19 Forecast 2 | 3 | https://app-virus-forecaster.herokuapp.com/ 4 | -------------------------------------------------------------------------------- /application/static/logo.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdipietro09/App_VirusForecaster/HEAD/application/static/logo.PNG -------------------------------------------------------------------------------- /application/static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdipietro09/App_VirusForecaster/HEAD/application/static/favicon.ico -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # RUN MAIN # 3 | ############################################################################### 4 | 5 | from application.dash import app 6 | from settings import config 7 | 8 | 9 | 10 | app.run_server(debug=config.debug, host=config.host, port=config.port) -------------------------------------------------------------------------------- /settings/config.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | 5 | 6 | ## App settings 7 | name = "Virus Forecaster" 8 | 9 | host = "0.0.0.0" 10 | 11 | port = int(os.environ.get("PORT", 5000)) 12 | 13 | debug = False 14 | 15 | contacts = "https://www.linkedin.com/in/mauro-di-pietro-56a1366b/" 16 | 17 | code = "https://github.com/mdipietro09/App_VirusForecaster" 18 | 19 | tutorial = "https://towardsdatascience.com/how-to-embed-bootstrap-css-js-in-your-python-dash-app-8d95fc9e599e" 20 | 21 | fontawesome = 'https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css' 22 | 23 | 24 | 25 | ## File system 26 | root = os.path.dirname(os.path.dirname(__file__)) + "/" 27 | 28 | 29 | 30 | ## DB 31 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==19.3.0 2 | backcall==0.1.0 3 | bleach==3.1.0 4 | certifi==2019.11.28 5 | click==7.1.1 6 | colorama==0.4.3 7 | cycler==0.10.0 8 | dash==1.11.0 9 | dash-bootstrap-components==0.9.2 10 | dash-core-components==1.9.1 11 | dash-html-components==1.0.3 12 | dash-renderer==1.4.0 13 | dash-table==4.6.2 14 | decorator==4.4.2 15 | defusedxml==0.6.0 16 | entrypoints==0.3 17 | Flask==1.1.1 18 | Flask-Compress==1.4.0 19 | future==0.18.2 20 | gunicorn==20.0.4 21 | importlib-metadata==1.5.0 22 | ipykernel==5.1.4 23 | ipython==7.13.0 24 | ipython-genutils==0.2.0 25 | itsdangerous==1.1.0 26 | jedi==0.16.0 27 | Jinja2==2.11.1 28 | jsonschema==3.2.0 29 | jupyter-client==6.1.2 30 | jupyter-core==4.6.1 31 | kiwisolver==1.1.0 32 | MarkupSafe==1.1.1 33 | matplotlib==3.2.1 34 | mistune==0.8.4 35 | nbconvert==5.6.1 36 | nbformat==5.0.4 37 | notebook==6.0.3 38 | numpy==1.18.2 39 | pandas==1.0.3 40 | pandocfilters==1.4.2 41 | parso==0.6.2 42 | pickleshare==0.7.5 43 | plotly==4.6.0 44 | prometheus-client==0.7.1 45 | prompt-toolkit==3.0.4 46 | Pygments==2.6.1 47 | pyparsing==2.4.6 48 | pyrsistent==0.16.0 49 | python-dateutil==2.8.1 50 | pytz==2019.3 51 | pywinpty==0.5.7 52 | pyzmq==18.1.1 53 | retrying==1.3.3 54 | scipy==1.4.1 55 | Send2Trash==1.5.0 56 | six==1.14.0 57 | terminado==0.8.3 58 | testpath==0.4.4 59 | tornado==6.0.4 60 | traitlets==4.3.3 61 | wcwidth==0.1.9 62 | webencodings==0.5.1 63 | Werkzeug==1.0.0 64 | wincertstore==0.2 65 | zipp==2.2.0 66 | -------------------------------------------------------------------------------- /python/data.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | 4 | 5 | 6 | class Data(): 7 | 8 | def get_data(self): 9 | self.dtf_cases = pd.read_csv("https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_confirmed_global.csv", sep=",") 10 | self.dtf_deaths = pd.read_csv("https://raw.githubusercontent.com/CSSEGISandData/COVID-19/master/csse_covid_19_data/csse_covid_19_time_series/time_series_covid19_deaths_global.csv", sep=",") 11 | #self.geo = self.dtf_cases[['Country/Region','Lat','Long']].drop_duplicates("Country/Region", keep='first') 12 | self.countrylist = ["World"] + self.dtf_cases["Country/Region"].unique().tolist() 13 | 14 | 15 | @staticmethod 16 | def group_by_country(dtf, country): 17 | dtf = dtf.drop(['Province/State','Lat','Long'], axis=1).groupby("Country/Region").sum().T 18 | dtf["World"] = dtf.sum(axis=1) 19 | dtf = dtf[country] 20 | dtf.index = pd.to_datetime(dtf.index, infer_datetime_format=True) 21 | ts = pd.DataFrame(index=dtf.index, data=dtf.values, columns=["data"]) 22 | return ts 23 | 24 | 25 | @staticmethod 26 | def calculate_mortality(ts_deaths, ts_cases): 27 | last_deaths = ts_deaths["data"].iloc[-1] 28 | last_cases = ts_cases["data"].iloc[-1] 29 | mortality = last_deaths / last_cases 30 | return mortality 31 | 32 | 33 | def process_data(self, country): 34 | self.dtf = self.group_by_country(self.dtf_cases, country) 35 | deaths = self.group_by_country(self.dtf_deaths, country) 36 | self.dtf["deaths"] = deaths 37 | self.mortality = self.calculate_mortality(deaths, self.dtf) -------------------------------------------------------------------------------- /python/model.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | import numpy as np 4 | from scipy import optimize 5 | 6 | 7 | 8 | class Model(): 9 | 10 | def __init__(self, dtf): 11 | self.dtf = dtf 12 | 13 | 14 | @staticmethod 15 | def f(X, c, k, m): 16 | y = c / (1 + np.exp(-k*(X-m))) 17 | return y 18 | 19 | 20 | @staticmethod 21 | def fit_parametric(X, y, f, p0): 22 | model, cov = optimize.curve_fit(f, X, y, maxfev=10000, p0=p0) 23 | return model 24 | 25 | 26 | @staticmethod 27 | def forecast_parametric(model, f, X): 28 | preds = f(X, model[0], model[1], model[2]) 29 | return preds 30 | 31 | 32 | @staticmethod 33 | def generate_indexdate(start): 34 | index = pd.date_range(start=start, periods=30, freq="D") 35 | index = index[1:] 36 | return index 37 | 38 | 39 | @staticmethod 40 | def add_diff(dtf): 41 | ## create delta columns 42 | dtf["delta_data"] = dtf["data"] - dtf["data"].shift(1) 43 | dtf["delta_forecast"] = dtf["forecast"] - dtf["forecast"].shift(1) 44 | 45 | ## fill Nas 46 | dtf["delta_data"] = dtf["delta_data"].fillna(method='bfill') 47 | dtf["delta_forecast"] = dtf["delta_forecast"].fillna(method='bfill') 48 | 49 | ## interpolate outlier 50 | idx = dtf[pd.isnull(dtf["data"])]["delta_forecast"].index[0] 51 | posx = dtf.index.tolist().index(idx) 52 | posx_a = posx - 1 53 | posx_b = posx + 1 54 | dtf["delta_forecast"].iloc[posx] = (dtf["delta_forecast"].iloc[posx_a] + dtf["delta_forecast"].iloc[posx_b])/2 55 | return dtf 56 | 57 | 58 | def forecast(self): 59 | ## fit 60 | y = self.dtf["data"].values 61 | t = np.arange(len(y)) 62 | model = self.fit_parametric(t, y, self.f, p0=[np.max(y), 1, 1]) 63 | fitted = self.f(t, model[0], model[1], model[2]) 64 | self.dtf["forecast"] = fitted 65 | 66 | ## forecast 67 | t_ahead = np.arange(len(y)+1, len(y)+30) 68 | forecast = self.forecast_parametric(model, self.f, t_ahead) 69 | 70 | ## create dtf 71 | self.today = self.dtf.index[-1] 72 | idxdates = self.generate_indexdate(start=self.today) 73 | preds = pd.DataFrame(data=forecast, index=idxdates, columns=["forecast"]) 74 | self.dtf = self.dtf.append(preds) 75 | 76 | ## add diff 77 | self.dtf = self.add_diff(self.dtf) 78 | 79 | 80 | def add_deaths(self, mortality): 81 | self.dtf["deaths"] = self.dtf[["deaths","forecast"]].apply(lambda x: 82 | mortality*x[1] if np.isnan(x[0]) else x[0], 83 | axis=1) -------------------------------------------------------------------------------- /python/result.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | import plotly.graph_objects as go 4 | 5 | 6 | 7 | class Result(): 8 | 9 | def __init__(self, dtf): 10 | self.dtf = dtf 11 | 12 | 13 | @staticmethod 14 | def calculate_peak(dtf): 15 | data_max = dtf["delta_data"].max() 16 | forecast_max = dtf["delta_forecast"].max() 17 | if data_max >= forecast_max: 18 | peak_day = dtf[dtf["delta_data"]==data_max].index[0] 19 | return peak_day, data_max 20 | else: 21 | peak_day = dtf[dtf["delta_forecast"]==forecast_max].index[0] 22 | return peak_day, forecast_max 23 | 24 | 25 | @staticmethod 26 | def calculate_max(dtf): 27 | total_cases_until_today = dtf["data"].max() 28 | total_cases_in_30days = dtf["forecast"].max() 29 | active_cases_today = dtf["delta_data"].max() 30 | active_cases_in_30days = dtf["delta_forecast"].max() 31 | return total_cases_until_today, total_cases_in_30days, active_cases_today, active_cases_in_30days 32 | 33 | 34 | def plot_total(self, today): 35 | ## main plots 36 | fig = go.Figure() 37 | fig.add_trace(go.Scatter(x=self.dtf.index, y=self.dtf["data"], mode='markers', name='data', line={"color":"black"})) 38 | fig.add_trace(go.Scatter(x=self.dtf.index, y=self.dtf["forecast"], mode='none', name='forecast', fill='tozeroy')) 39 | fig.add_trace(go.Bar(x=self.dtf.index, y=self.dtf["deaths"], name='deaths', marker_color='red')) 40 | ## add slider 41 | fig.update_xaxes(rangeslider_visible=True) 42 | ## set background color 43 | fig.update_layout(plot_bgcolor='white', autosize=False, width=1000, height=550) 44 | ## add vline 45 | fig.add_shape({"x0":today, "x1":today, "y0":0, "y1":self.dtf["forecast"].max(), 46 | "type":"line", "line":{"width":2,"dash":"dot"} }) 47 | fig.add_trace(go.Scatter(x=[today], y=[self.dtf["forecast"].max()], text=["today"], mode="text", line={"color":"green"}, showlegend=False)) 48 | return fig 49 | 50 | 51 | def plot_active(self, today): 52 | ## main plots 53 | fig = go.Figure() 54 | fig.add_trace(go.Bar(x=self.dtf.index, y=self.dtf["delta_data"], name='data', marker_color='black')) 55 | fig.add_trace(go.Scatter(x=self.dtf.index, y=self.dtf["delta_forecast"], mode='none', name='forecast', fill='tozeroy')) 56 | ## add slider 57 | fig.update_xaxes(rangeslider_visible=True) 58 | ## set background color 59 | fig.update_layout(plot_bgcolor='white', autosize=False, width=1000, height=550) 60 | ## add vline 61 | fig.add_shape({"x0":today, "x1":today, "y0":0, "y1":self.dtf["delta_forecast"].max(), 62 | "type":"line", "line":{"width":2,"dash":"dot"} }) 63 | fig.add_trace(go.Scatter(x=[today], y=[self.dtf["delta_forecast"].max()], text=["today"], mode="text", line={"color":"green"}, showlegend=False)) 64 | return fig 65 | 66 | 67 | def get_panel(self): 68 | peak_day, num_max = self.calculate_peak(self.dtf) 69 | total_cases_until_today, total_cases_in_30days, active_cases_today, active_cases_in_30days = self.calculate_max(self.dtf) 70 | return peak_day, num_max, total_cases_until_today, total_cases_in_30days, active_cases_today, active_cases_in_30days -------------------------------------------------------------------------------- /application/dash.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # MAIN # 3 | ############################################################################### 4 | 5 | # Setup 6 | import dash 7 | from dash.dependencies import Input, Output, State 8 | import dash_core_components as dcc 9 | import dash_html_components as html 10 | import dash_bootstrap_components as dbc 11 | 12 | from settings import config, about 13 | from python.data import Data 14 | from python.model import Model 15 | from python.result import Result 16 | 17 | 18 | 19 | # Read data 20 | data = Data() 21 | data.get_data() 22 | 23 | 24 | 25 | # App Instance 26 | app = dash.Dash(name=config.name, assets_folder=config.root+"/application/static", external_stylesheets=[dbc.themes.LUX, config.fontawesome]) 27 | app.title = config.name 28 | 29 | 30 | 31 | # Navbar 32 | navbar = dbc.Nav(className="nav nav-pills", children=[ 33 | ## logo/home 34 | dbc.NavItem(html.Img(src=app.get_asset_url("logo.PNG"), height="40px")), 35 | ## about 36 | dbc.NavItem(html.Div([ 37 | dbc.NavLink("About", href="/", id="about-popover", active=False), 38 | dbc.Popover(id="about", is_open=False, target="about-popover", children=[ 39 | dbc.PopoverHeader("How it works"), dbc.PopoverBody(about.txt) 40 | ]) 41 | ])), 42 | ## links 43 | dbc.DropdownMenu(label="Links", nav=True, children=[ 44 | dbc.DropdownMenuItem([html.I(className="fa fa-linkedin"), " Contacts"], href=config.contacts, target="_blank"), 45 | dbc.DropdownMenuItem([html.I(className="fa fa-github"), " Code"], href=config.code, target="_blank"), 46 | dbc.DropdownMenuItem([html.I(className="fa fa-medium"), " Tutorial"], href=config.tutorial, target="_blank") 47 | ]) 48 | ]) 49 | 50 | 51 | 52 | # Input 53 | inputs = dbc.FormGroup([ 54 | html.H4("Select Country"), 55 | dcc.Dropdown(id="country", options=[{"label":x,"value":x} for x in data.countrylist], value="World") 56 | ]) 57 | 58 | 59 | 60 | # App Layout 61 | app.layout = dbc.Container(fluid=True, children=[ 62 | ## Top 63 | html.H1(config.name, id="nav-pills"), 64 | navbar, 65 | html.Br(),html.Br(),html.Br(), 66 | 67 | ## Body 68 | dbc.Row([ 69 | ### input + panel 70 | dbc.Col(md=3, children=[ 71 | inputs, 72 | html.Br(),html.Br(),html.Br(), 73 | html.Div(id="output-panel") 74 | ]), 75 | ### plots 76 | dbc.Col(md=9, children=[ 77 | dbc.Col(html.H4("Forecast 30 days from today"), width={"size":6,"offset":3}), 78 | dbc.Tabs(className="nav nav-pills", children=[ 79 | dbc.Tab(dcc.Graph(id="plot-total"), label="Total cases"), 80 | dbc.Tab(dcc.Graph(id="plot-active"), label="Active cases") 81 | ]) 82 | ]) 83 | ]) 84 | ]) 85 | 86 | 87 | 88 | # Python functions for about navitem-popover 89 | @app.callback(output=Output("about","is_open"), inputs=[Input("about-popover","n_clicks")], state=[State("about","is_open")]) 90 | def about_popover(n, is_open): 91 | if n: 92 | return not is_open 93 | return is_open 94 | 95 | @app.callback(output=Output("about-popover","active"), inputs=[Input("about-popover","n_clicks")], state=[State("about-popover","active")]) 96 | def about_active(n, active): 97 | if n: 98 | return not active 99 | return active 100 | 101 | 102 | 103 | # Python function to plot total cases 104 | @app.callback(output=Output("plot-total","figure"), inputs=[Input("country","value")]) 105 | def plot_total_cases(country): 106 | data.process_data(country) 107 | model = Model(data.dtf) 108 | model.forecast() 109 | model.add_deaths(data.mortality) 110 | result = Result(model.dtf) 111 | return result.plot_total(model.today) 112 | 113 | 114 | 115 | # Python function to plot active cases 116 | @app.callback(output=Output("plot-active","figure"), inputs=[Input("country","value")]) 117 | def plot_active_cases(country): 118 | data.process_data(country) 119 | model = Model(data.dtf) 120 | model.forecast() 121 | model.add_deaths(data.mortality) 122 | result = Result(model.dtf) 123 | return result.plot_active(model.today) 124 | 125 | 126 | 127 | # Python function to render output panel 128 | @app.callback(output=Output("output-panel","children"), inputs=[Input("country","value")]) 129 | def render_output_panel(country): 130 | data.process_data(country) 131 | model = Model(data.dtf) 132 | model.forecast() 133 | model.add_deaths(data.mortality) 134 | result = Result(model.dtf) 135 | peak_day, num_max, total_cases_until_today, total_cases_in_30days, active_cases_today, active_cases_in_30days = result.get_panel() 136 | peak_color = "white" if model.today > peak_day else "red" 137 | panel = html.Div([ 138 | html.H4(country), 139 | dbc.Card(body=True, className="text-white bg-primary", children=[ 140 | 141 | html.H6("Total cases until today:", style={"color":"white"}), 142 | html.H3("{:,.0f}".format(total_cases_until_today), style={"color":"white"}), 143 | 144 | html.H6("Total cases in 30 days:", className="text-danger"), 145 | html.H3("{:,.0f}".format(total_cases_in_30days), className="text-danger"), 146 | 147 | html.H6("Active cases today:", style={"color":"white"}), 148 | html.H3("{:,.0f}".format(active_cases_today), style={"color":"white"}), 149 | 150 | html.H6("Active cases in 30 days:", className="text-danger"), 151 | html.H3("{:,.0f}".format(active_cases_in_30days), className="text-danger"), 152 | 153 | html.H6("Peak day:", style={"color":peak_color}), 154 | html.H3(peak_day.strftime("%Y-%m-%d"), style={"color":peak_color}), 155 | html.H6("with {:,.0f} cases".format(num_max), style={"color":peak_color}) 156 | 157 | ]) 158 | ]) 159 | return panel --------------------------------------------------------------------------------