├── .gitignore
├── LICENSE
├── README.md
├── SIRmodel_cl_op.py
└── SIRmodel_cl_op_sn.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# custom file exclusions
SIRmodel.py
SIRmodel_prof.txt
SIRmodel_cl_op.py.lprof
SIRmodel_cl_op_sn.py.lprof

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 prithvidiamond1 (R R PRITHVIRAJ)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SIR model based epidemiology curve generator in Matplotlib
"Who needs GeoGebra (no offense) when we've got Python!" - Me

A simple script I developed to simulate an epidemic using the SIR model of epidemiology.
I have optimised the script to the best of my knowledge, and I have included some basic comments at the important parts so as not to leave those who need them hanging.

Note: I initially decided to leave the optimisation as is, but then changed my mind. Reddit users such as [u/vlovero](https://www.reddit.com/user/vlovero) and [u/kokoistheway](https://www.reddit.com/user/kokoistheway) helped me improve the code and speed things up by introducing me to [numba](http://numba.pydata.org/) (special thanks to u/vlovero for that!). Some of you [scipy](https://docs.scipy.org/doc/scipy/reference/index.html) fans may also be wondering why I didn't use scipy's [integrate.solve\_ivp](https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html). I did try it (again, all thanks to u/vlovero for bringing it to my attention), but I was not satisfied with it because the data it generated produced a less smooth curve. Don't get me wrong, its solution was still correct: my original solution and scipy's were similar. So I ended up sticking with my original solution.

Here is the original Reddit thread where I asked for help with optimisation, if you are interested: [Click here!](https://www.reddit.com/r/learnpython/comments/fropmx/very_poor_performance_in_my_matplotlib_script_at/)

I am also leaving both versions here: my own optimised version (SIRmodel\_cl\_op.py) and the one written with the help mentioned above (SIRmodel\_cl\_op\_sn.py), for those who are interested in the differences and would love to learn from and play around with both!

This project was inspired by the Numberphile video explaining the mathematical modeling of an epidemic. I suggest that anybody who doesn't already know what this is about check the video out!

Video link: [Click here!](https://www.youtube.com/watch?v=k6nLfCbAzgo)

So enjoy, and feel free to play around with my script! Also, stay safe! (For anybody from the future: this was made during the coronavirus pandemic.)

(You looking for a license? Well, there wasn't one until a Reddit user told me to add one... so here it is, the MIT License. Enjoy!)

--------------------------------------------------------------------------------
/SIRmodel_cl_op.py:
--------------------------------------------------------------------------------

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button

p = 1 #population
i = 0.01*p #infected
s = p-i #susceptible
r = 0 #recovered/removed

a = 3.2 #transmission parameter
b = 0.23 #recovery parameter

initialTime = 0
deltaTime = 0.001 #the smaller the delta, the better the approximation to a real derivative
maxTime = 10000 #number of time steps; more points give a smoother curve

#the SIR differential equations, expressed as functions that calculate the rate of change
#of each quantity (susceptible, infected and recovered/removed) between time steps
def sPrime(oldS, oldI, transmissionRate):
    return -1*((transmissionRate*oldS*oldI)/p)

def iPrime(oldS, oldI, transmissionRate, recoveryRate):
    return (((transmissionRate*oldS)/p)-recoveryRate)*oldI

def rPrime(oldI, recoveryRate):
    return recoveryRate*oldI

maxTimeInitial = maxTime

def genData(transRate, recovRate, maxT):
    global a, b, maxTimeInitial
    a = transRate
    b = recovRate
    maxTimeInitial = maxT

    sInitial = s
    iInitial = i
    rInitial = r

    time = np.arange(maxTimeInitial+1) #step index; the model time at step t is t*deltaTime
    sVals = np.zeros(maxTimeInitial+1)
    iVals = np.zeros(maxTimeInitial+1)
    rVals = np.zeros(maxTimeInitial+1)

    for t in range(initialTime, maxTimeInitial+1): #generating the data through a loop of forward Euler steps
        sVals[t] = sInitial
        iVals[t] = iInitial
        rVals[t] = rInitial

        newDeltas = (sPrime(sInitial, iInitial, transmissionRate=a), iPrime(sInitial, iInitial, transmissionRate=a, recoveryRate=b), rPrime(iInitial, recoveryRate=b))
        sInitial += newDeltas[0]*deltaTime
        iInitial += newDeltas[1]*deltaTime
        rInitial += newDeltas[2]*deltaTime

        if sInitial < 0 or iInitial < 0 or rInitial < 0: #as soon as any of these values becomes negative, the generated data is invalid,
            break #since the SIR model assumes that S, I and R are never negative

    return (time, sVals, iVals, rVals)

fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.4, top=0.94)

plt.title('SIR epidemiology curves for a disease')

plt.xlim(0, maxTime+1)
plt.ylim(0, p*1.4)

plt.xlabel('Time (t)')
plt.ylabel('Population (p)')

initialData = genData(a, b, maxTimeInitial)

susceptible, = ax.plot(initialData[0], initialData[1], label='Susceptible', color='b')
infected, = ax.plot(initialData[0], initialData[2], label='Infected', color='r')
recovered, = ax.plot(initialData[0], initialData[3], label='Recovered/Removed', color='g')

plt.legend()

transmissionAxes = plt.axes([0.125, 0.25, 0.775, 0.03], facecolor='white')
recoveryAxes = plt.axes([0.125, 0.2, 0.775, 0.03], facecolor='white')
timeAxes = plt.axes([0.125, 0.15, 0.775, 0.03], facecolor='white')

transmissionSlider = Slider(transmissionAxes, 'Transmission parameter', 0, 10, valinit=a, valstep=0.01)
recoverySlider = Slider(recoveryAxes, 'Recovery parameter', 0, 10, valinit=b, valstep=0.01)
timeSlider = Slider(timeAxes, 'Max time', 0, 100000, valinit=maxTime, valstep=1, valfmt="%i")

def updateTransmission(newVal):
    newData = genData(newVal, b, maxTimeInitial) #genData also stores newVal in the global a

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_0$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateRecovery(newVal):
    newData = genData(a, newVal, maxTimeInitial) #genData also stores newVal in the global b

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_0$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateMaxTime(newVal):
    newData = genData(a, b, int(newVal)) #genData also stores the new length in the global maxTimeInitial

    #update the existing lines in place; deleting entries from ax.lines is not supported in recent Matplotlib releases
    susceptible.set_data(newData[0], newData[1])
    infected.set_data(newData[0], newData[2])
    recovered.set_data(newData[0], newData[3])

    fig.canvas.draw_idle()

transmissionSlider.on_changed(updateTransmission)
recoverySlider.on_changed(updateRecovery)
timeSlider.on_changed(updateMaxTime)

resetAxes = plt.axes([0.8, 0.025, 0.1, 0.05])
resetButton = Button(resetAxes, 'Reset', color='white')

#drawn on the current axes (the reset button's axes), so it appears just above the button
r_o = plt.text(0.1, 1.5, r'$R_0$={:.2f}'.format(a/b), fontsize=12)

def reset(event):
    transmissionSlider.reset()
    recoverySlider.reset()
    timeSlider.reset()

resetButton.on_clicked(reset)

plt.show()

--------------------------------------------------------------------------------
/SIRmodel_cl_op_sn.py:
--------------------------------------------------------------------------------

from numba import njit
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button

p = 1 #population
i = 0.01*p #infected
s = p-i #susceptible
r = 0 #recovered/removed

a = 3.2 #transmission parameter
b = 0.23 #recovery parameter

initialTime = 0
deltaTime = 0.001 #the smaller the delta, the better the approximation to a real derivative
maxTime = 10000 #number of time steps; more points give a smoother curve

#the SIR differential equations, expressed as functions that calculate the rate of change
#of each quantity (susceptible, infected and recovered/removed) between time steps
@njit(nogil=True)
def sPrime(oldS, oldI, transmissionRate):
    return -1*((transmissionRate*oldS*oldI)/p)

@njit(nogil=True)
def iPrime(oldS, oldI, transmissionRate, recoveryRate):
    return (((transmissionRate*oldS)/p)-recoveryRate)*oldI

@njit(nogil=True)
def rPrime(oldI, recoveryRate):
    return recoveryRate*oldI

maxTimeInitial = maxTime

#note: Numba treats the globals used here (s, i, r, initialTime, deltaTime) as compile-time constants
@njit(nogil=True)
def genData(transRate, recovRate, maxT):
    sInitial = s
    iInitial = i
    rInitial = r

    time = np.arange(maxT+1) #step index; the model time at step t is t*deltaTime
    sVals = np.zeros(maxT+1)
    iVals = np.zeros(maxT+1)
    rVals = np.zeros(maxT+1)

    #each forward Euler step depends on the previous one, so this loop must run sequentially;
    #prange with parallel=True would split its iterations across threads and break that dependency
    for t in range(initialTime, maxT+1):
        sVals[t] = sInitial
        iVals[t] = iInitial
        rVals[t] = rInitial

        newDeltas = (sPrime(sInitial, iInitial, transmissionRate=transRate), iPrime(sInitial, iInitial, transmissionRate=transRate, recoveryRate=recovRate), rPrime(iInitial, recoveryRate=recovRate))
        sInitial += newDeltas[0]*deltaTime
        iInitial += newDeltas[1]*deltaTime
        rInitial += newDeltas[2]*deltaTime

        if sInitial < 0 or iInitial < 0 or rInitial < 0: #as soon as any of these values becomes negative, the generated data is invalid,
            break #since the SIR model assumes that S, I and R are never negative

    return (time, sVals, iVals, rVals)

fig, ax = plt.subplots()
plt.subplots_adjust(bottom=0.4, top=0.94)

plt.title('SIR epidemiology curves for a disease')

plt.xlim(0, maxTime+1)
plt.ylim(0, p*1.4)

plt.xlabel('Time (t)')
plt.ylabel('Population (p)')

initialData = genData(a, b, maxTimeInitial)

susceptible, = ax.plot(initialData[0], initialData[1], label='Susceptible', color='b')
infected, = ax.plot(initialData[0], initialData[2], label='Infected', color='r')
recovered, = ax.plot(initialData[0], initialData[3], label='Recovered/Removed', color='g')

plt.legend()

transmissionAxes = plt.axes([0.125, 0.25, 0.775, 0.03], facecolor='white')
recoveryAxes = plt.axes([0.125, 0.2, 0.775, 0.03], facecolor='white')
timeAxes = plt.axes([0.125, 0.15, 0.775, 0.03], facecolor='white')

transmissionSlider = Slider(transmissionAxes, 'Transmission parameter', 0, 10, valinit=a, valstep=0.01)
recoverySlider = Slider(recoveryAxes, 'Recovery parameter', 0, 10, valinit=b, valstep=0.01)
timeSlider = Slider(timeAxes, 'Max time', 0, 100000, valinit=maxTime, valstep=1, valfmt="%i")

def updateTransmission(newVal):
    global a
    a = newVal

    newData = genData(newVal, b, maxTimeInitial)

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_0$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateRecovery(newVal):
    global b
    b = newVal

    newData = genData(a, newVal, maxTimeInitial)

    susceptible.set_ydata(newData[1])
    infected.set_ydata(newData[2])
    recovered.set_ydata(newData[3])

    r_o.set_text(r'$R_0$={:.2f}'.format(a/b))

    fig.canvas.draw_idle()

def updateMaxTime(newVal):
    global maxTimeInitial
    maxTimeInitial = int(newVal)

    newData = genData(a, b, maxTimeInitial)

    #update the existing lines in place; deleting entries from ax.lines is not supported in recent Matplotlib releases
    susceptible.set_data(newData[0], newData[1])
    infected.set_data(newData[0], newData[2])
    recovered.set_data(newData[0], newData[3])

    fig.canvas.draw_idle()

transmissionSlider.on_changed(updateTransmission)
recoverySlider.on_changed(updateRecovery)
timeSlider.on_changed(updateMaxTime)

resetAxes = plt.axes([0.8, 0.025, 0.1, 0.05])
resetButton = Button(resetAxes, 'Reset', color='white')

#drawn on the current axes (the reset button's axes), so it appears just above the button
r_o = plt.text(0.1, 1.5, r'$R_0$={:.2f}'.format(a/b), fontsize=12)

def reset(event):
    transmissionSlider.reset()
    recoverySlider.reset()
    timeSlider.reset()

resetButton.on_clicked(reset)

plt.show()

--------------------------------------------------------------------------------
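A minimal standalone sketch, not part of the repository, of the forward Euler update that both versions of genData() perform, with every parameter passed explicitly instead of read from globals. The helper name `sir_euler` and its keyword arguments are hypothetical; the defaults mirror the scripts' values (1% initially infected, p = 1, deltaTime = 0.001, 10000 steps). It also checks two properties that are handy for sanity-checking the curves: because dS/dt + dI/dt + dR/dt = 0, S + I + R stays equal to p up to rounding, and the infected curve peaks roughly where S falls to p·b/a, i.e. p/R_0.

```python
import numpy as np


def sir_euler(a, b, p=1.0, i0=0.01, dt=0.001, steps=10000):
    """Forward Euler integration of the SIR equations (hypothetical helper).

    a: transmission parameter, b: recovery parameter, p: population,
    i0: initially infected fraction of p, dt: step size, steps: number of steps.
    """
    s = np.empty(steps + 1)
    i = np.empty(steps + 1)
    r = np.empty(steps + 1)
    s[0], i[0], r[0] = p - i0 * p, i0 * p, 0.0

    for n in range(steps):
        infections = a * s[n] * i[n] / p  # flow from S to I per unit time
        recoveries = b * i[n]  # flow from I to R per unit time
        # each new value depends only on the previous step (explicit Euler)
        s[n + 1] = s[n] - dt * infections
        i[n + 1] = i[n] + dt * (infections - recoveries)
        r[n + 1] = r[n] + dt * recoveries

    t = np.arange(steps + 1) * dt  # model time, not the step index
    return t, s, i, r


a, b = 3.2, 0.23
t, s, i, r = sir_euler(a, b)
peak = int(np.argmax(i))
print(f"R0 = a/b = {a / b:.2f}")
print(f"largest drift of S+I+R from p: {np.abs(s + i + r - 1.0).max():.1e}")
print(f"infections peak at t = {t[peak]:.2f} with S = {s[peak]:.4f} (p*b/a = {b / a:.4f})")
```

Because every step depends on the one before it, this loop cannot be split across threads over t. Note also that the scripts plot the step index on the x-axis, whereas `t` here holds the model time (step index times the step size). If you compare against scipy's `solve_ivp`, keep in mind that by default it returns only the time points chosen by its adaptive solver; passing `t_eval` (or `dense_output=True`) samples the solution on a grid of your choosing.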