├── .gitignore
├── LICENSE
├── README.md
├── fig
│   ├── sc.png
│   ├── sdid.png
│   ├── sdid2.png
│   ├── sdid3.png
│   └── sdid_plot.png
├── notebook
│   ├── ClassicSCM.ipynb
│   ├── OtherOmegaEstimationMethods.ipynb
│   ├── ReproductionExperiment_CaliforniaSmoking.ipynb
│   ├── Rsythdid.Rmd
│   ├── ScaleTesting_of_DonorPools.ipynb
│   └── fig
│       ├── NewHampshire.png
│       ├── california.png
│       ├── original_lambda.png
│       ├── original_omega.png
│       └── r_synthdid_result.png
├── sample_data
│   ├── .Rhistory
│   ├── MLAB_data.txt
│   └── README.md
├── setup.cfg
├── setup.py
├── synthdid
│   ├── __init__.py
│   ├── __version__.py
│   ├── model.py
│   ├── optimizer.py
│   ├── plot.py
│   ├── sample_data.py
│   ├── summary.py
│   └── variance.py
└── test
    ├── test_data
    │   ├── lambda_CalifolinaSmoking.csv
    │   └── omega_CalifolinaSmoking.csv
    └── test_model.py

/.gitignore:
--------------------------------------------------------------------------------
# config
*.ini

# Byte-compiled / optimized / DLL files
__pycache__/
*$py.class

# C extensions
*.so

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# etc
.DS_Store
.vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pysynthdid: Synthetic difference in differences for Python

## What is Synthetic Difference in Differences?
### Original paper:
Arkhangelsky, Dmitry, et al. "Synthetic Difference in Differences." NBER Working Paper No. w25532, National Bureau of Economic Research, 2019. https://www.nber.org/papers/w25532
### R package:
https://github.com/synth-inference/synthdid

### Blog:
https://medium.com/@masa_asami/causal-inference-using-synthetic-difference-in-differences-with-python-5758e5a76909

## Installation:

```
$ pip install git+https://github.com/MasaAsami/pysynthdid
```

This package is still under development. I plan to register it on `PyPI` once the following are complete:
- Refactoring and better documentation
- Completion of the test code

## How to use:
### Here's a simple example:
- Setup
```python
from synthdid.model import SynthDID
from synthdid.sample_data import fetch_CaliforniaSmoking

df = fetch_CaliforniaSmoking()

PRE_TERM = [1970, 1988]
POST_TERM = [1989, 2000]

TREATMENT = ["California"]
```
- Estimation & plot
```python
sdid = SynthDID(df, PRE_TERM, POST_TERM, TREATMENT)
sdid.fit(zeta_type="base")
sdid.plot(model="sdid")
```

- Details of each method will be documented later; an illustrative sketch is shown at the end of this README.
### See the Jupyter [`notebook`](https://github.com/MasaAsami/pysynthdid/tree/main/notebook) directory for basic usage
- `ReproductionExperiment_CaliforniaSmoking.ipynb`
  - A reproduction of the experiments in the original paper, using the well-known CaliforniaSmoking dataset.

- `OtherOmegaEstimationMethods.ipynb`
  - This notebook explores alternative estimation methods for the parameter `omega` (and `zeta`). The results confirm the robustness of the estimation method used in the original paper.

- `ScaleTesting_of_DonorPools.ipynb`
  - This notebook examines how the estimation results change with the scale of the donor-pool features.
  - Adding donor pools on extremely different scales (e.g., 10x) can introduce substantial bias into the estimates.
  - If features on different scales are mixed, preprocessing such as a logarithmic transformation is likely to be necessary, just as in traditional regression.

## Discussions and PR:
- This module is still under development.
- If you have any questions or comments, please feel free to open an issue.
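
### Getting the estimated effect (illustrative sketch)
The snippet below only uses methods already defined in `synthdid/model.py` (`hat_tau`, `cal_se`, `estimated_params`); since the package is still under development, the exact API may change.
```python
# After sdid.fit(), the ATT and placebo standard errors can be computed:
tau = sdid.hat_tau(model="sdid")  # average treatment effect on the treated
sdid.cal_se(algo="placebo", replications=200)
print(f"ATT: {tau:.3f} (se: {sdid.sdid_se:.3f})")

# Estimated unit and time weights:
unit_weights, time_weights = sdid.estimated_params(model="sdid")
```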

--------------------------------------------------------------------------------
/fig/sc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/fig/sc.png
--------------------------------------------------------------------------------
/fig/sdid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/fig/sdid.png
--------------------------------------------------------------------------------
/fig/sdid2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/fig/sdid2.png
--------------------------------------------------------------------------------
/fig/sdid3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/fig/sdid3.png
--------------------------------------------------------------------------------
/fig/sdid_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/fig/sdid_plot.png
--------------------------------------------------------------------------------
/notebook/Rsythdid.Rmd:
--------------------------------------------------------------------------------
---
title: "Synthdid"
output:
  html_document: default
  pdf_document: default
  word_document: default
---

## Setup and estimation

```{r setup, include=FALSE}
library(synthdid)

# Estimate the effect of California Proposition 99 on cigarette consumption
data('california_prop99')
setup = panel.matrices(california_prop99)
```

```{r tau}
tau.hat = synthdid_estimate(setup$Y, setup$N0, setup$T0)
```

## Plots

```{r plot, echo=FALSE}
plot(tau.hat)
```

## Summary

```{r summary}
summary(tau.hat)
```
--------------------------------------------------------------------------------
/notebook/fig/NewHampshire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/notebook/fig/NewHampshire.png
--------------------------------------------------------------------------------
/notebook/fig/california.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/notebook/fig/california.png
--------------------------------------------------------------------------------
/notebook/fig/original_lambda.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/notebook/fig/original_lambda.png
--------------------------------------------------------------------------------
/notebook/fig/original_omega.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/notebook/fig/original_omega.png -------------------------------------------------------------------------------- /notebook/fig/r_synthdid_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/notebook/fig/r_synthdid_result.png -------------------------------------------------------------------------------- /sample_data/.Rhistory: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MasaAsami/pysynthdid/01afe33ae22f513c65f9cfdec56a4b21ca547c28/sample_data/.Rhistory -------------------------------------------------------------------------------- /sample_data/MLAB_data.txt: -------------------------------------------------------------------------------- 1 | 1 2 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 3 2 | 9.678973622 9.643623246 9.984357198 10.18803512 9.974561161 9.817172262 9.711300956 10.00688288 9.831646389 9.836926672 9.916982969 9.695040385 9.747586568 9.786902746 9.938785023 9.546848297 9.877391073 9.753027174 9.850039482 10.02443239 10.00636715 9.708400938 9.751609802 9.756118351 9.886301253 9.814118067 9.926845233 9.931006961 9.673460537 9.702802976 9.737283919 9.896063487 9.678585158 9.821148766 9.957432535 9.65476354 9.882993592 9.913661109 10.07655864 3 | 89.34444512 89.8777771 82.62222205 103.4777764 90.05555513 84.36666658 86.07777786 89.83333249 81.08888838 90.65555615 87.78888872 71.48888991 90.0666665 91.89999898 96.28888872 88.92222214 84.46666675 85.5888888 89.48888906 93.24444538 83.39999941 87.48888736 71.44444402 88.90000068 84.55555513 90.47777812 89.17777803 90.22222307 76.61111196 88.54444461 85.17777846 92.47777854 89.4333335 88.02222273 74.78888914 92.58888753 95.15555784 81.00000042 89.42222341 4 | 0.174801901 0.164611373 0.173703247 0.163659688 0.178224497 0.176944127 0.152016721 0.17028118 0.175089656 0.169908937 0.170582652 0.17540008 0.181948922 0.16571967 0.172464155 0.181491156 0.166553891 0.16537521 0.168586676 0.162917763 0.169238773 0.174311015 0.179371097 0.180801443 0.169848301 0.168725929 0.164436671 0.175420049 0.184413918 0.173589869 0.171025995 0.177760008 0.187830499 0.177342362 0.177402943 0.164830512 0.174546391 0.174207053 0.173532382 5 | 18.95999985 18.52000008 25.08000031 20.7 26.08000031 21.75999985 22.22000008 24.74000015 21.97999992 23.23999977 19.94000015 18.94000015 23.87999992 22.44000015 23.12000008 21.14000015 23.88000069 27.87999992 24.71999969 37 34.95999985 27.97999992 19.92000008 23.5 24.02000008 18.13999977 25.07999992 25.54000015 22.9 21.26000023 20.57999992 28.57999992 13.33999996 27.05999985 22.99999962 19.80000038 32.04000015 24.97999992 24.28000031 6 | 112.0999985 121.5 94.59999847 104.8000031 137.1000061 124.0999985 84.5 107.5999985 134 100.1999969 103.1999969 173.1999969 110.9000015 125 94.09999847 109 127.4000015 87.09999847 92.90000153 141.8999939 180.3999939 77.69999695 146 87.09999847 122.4000015 103.5999985 107.5999985 138 124.4000015 91.90000153 125.3000031 96.5 55 128.6999969 129.5 109.0999985 102.5999985 114.3000031 90.09999847 7 | 123.1999969 131.8000031 131 118 150.5 134 115.1999969 135.1999969 146.8999939 124.5999985 127.0999985 215.3000031 143.8000031 141.1999969 117.6999969 127 142.1000061 122 116.3000031 177.6999969 247.8000031 102.6999969 187.8000031 
123.6999969 133.5 141.6000061 124 149.3000031 138.3000031 114.6999969 130.3999939 129.6999969 74.80000305 161.6000061 148.8999939 122.3000031 117.5999985 158.1000061 120.1999969 8 | 111.6999969 114.8000031 131 110.1999969 147.6000061 122.9000015 123.3000031 131.8000031 162.3999939 120.5 123.4000015 223 133.6000061 140.6999969 111.5 116.8000031 135.6000061 123.6999969 114.0999985 205.1999969 269.1000061 103.0999985 226 117.9000015 122.5 132.8999939 114.5999985 154.6999969 130.5 113.5 117.4000015 116 75.80000305 155.5 152.6999969 123.1999969 113.5 160.6999969 127.0999985 9 | 89.80000305 100.3000031 124.8000031 120 155 109.9000015 102.4000015 124.8000031 134.6000061 108.5 114 155.8000031 115.9000015 128.5 104.3000031 93.40000153 121.3000031 111.1999969 108.0999985 189.5 265.7000122 90 172.3999939 93.80000305 121.5999985 108.4000015 107.3000031 123.9000015 103.5999985 92.69999695 99.80000305 106.4000015 65.5 122.5999985 124.3000031 114.5 106.4000015 132.1999969 123 10 | 95.40000153 104.0999985 125.5 117.5999985 161.1000061 115.6999969 108.5 125.5999985 139.3000031 108.4000015 102.8000031 163.5 119.8000031 133.1999969 116.4000015 105.4000015 127.5999985 115.5999985 108.5999985 190.5 278 92.59999847 187.6000061 98.5 124.5999985 115.4000015 106.3000031 123.1999969 115 96.69999695 106.3000031 108.9000015 67.69999695 124.4000015 128.3999939 111.5 105.4000015 131.6999969 121 11 | 101.0999985 103.9000015 134.3000031 110.8000031 156.3000031 117 126.0999985 126.5999985 149.1999969 109.4000015 111 179.3999939 125.3000031 136.5 96.80000305 112.0999985 130 122.1999969 104.9000015 198.6000061 296.2000122 99.30000305 214.1000061 103.8000031 124.4000015 121.6999969 109 134.3999939 118.6999969 103 111.5 108.5999985 71.30000305 138 137 117.5 108.8000031 140 123.5 12 | 102.9000015 108 137.8999939 109.3000031 154.6999969 119.8000031 121.8000031 124.4000015 156 110.5999985 115.1999969 201.8999939 126.6999969 138 106.8000031 115 132.1000061 119.9000015 106.5999985 201.5 279 98.90000153 226.5 108.6999969 120.5 124.0999985 110.6999969 142 125.5 103.5 109.6999969 110.4000015 72.69999695 146.8000031 143.1000061 116.5999985 109.5 141.1999969 124.4000015 13 | 108.1999969 109.6999969 132.8000031 112.4000015 151.3000031 123.6999969 125.5999985 131.8999939 159.6000061 116.0999985 118.5999985 212.3999939 129.8999939 142.1000061 110.5999985 117.0999985 135.3999939 121.9000015 110.5 204.6999969 269.7999878 100.3000031 227.3000031 110.5 122.0999985 130.5 114.1999969 146.1000061 129.6999969 108.4000015 114.8000031 114.6999969 75.59999847 151.8000031 149.6000061 119.9000015 111.8000031 145.8000031 126.6999969 14 | 111.6999969 114.8000031 131 110.1999969 147.6000061 122.9000015 123.3000031 131.8000031 162.3999939 120.5 123.4000015 223 133.6000061 140.6999969 111.5 116.8000031 135.6000061 123.6999969 114.0999985 205.1999969 269.1000061 103.0999985 226 117.9000015 122.5 132.8999939 114.5999985 154.6999969 130.5 113.5 117.4000015 116 75.80000305 155.5 152.6999969 123.1999969 113.5 160.6999969 127.0999985 15 | 116.1999969 119.0999985 134.1999969 113.4000015 153 125.9000015 125.0999985 134.3999939 166.6000061 124.4000015 127.6999969 230.8999939 139.6000061 144.8999939 116.6999969 120.9000015 139.5 124.9000015 118.0999985 201.3999939 290.5 102.4000015 230.1999969 125.4000015 124.5999985 138.6000061 118.8000031 150.1999969 136.8000031 116.6999969 121.6999969 121.4000015 77.90000153 171.1000061 158.1000061 129.6999969 115.4000015 161.5 128 16 | 117.0999985 122.5999985 132 117.3000031 153.3000031 127.9000015 125 134 173 125.5 127.9000015 
229.3999939 140 145.6000061 117.1999969 122.0999985 140.8000031 127 117.6999969 190.8000031 278.7999878 102.4000015 217 122.1999969 127.3000031 140.3999939 120.0999985 148.8000031 137.1999969 115.5999985 124.5999985 124.1999969 78 169.3999939 157.6999969 133.8999939 117.1999969 160.3999939 126.4000015 17 | 123 127.3000031 129.1999969 117.5 155.5 130.6000061 122.8000031 136.6999969 150.8999939 127.0999985 127.0999985 224.6999969 142.6999969 143.8999939 118.9000015 124.9000015 141.8000031 127.1999969 117.4000015 187 269.6000061 103.0999985 205.5 121.9000015 131.3000031 143.6000061 122.3000031 146.8000031 140.3999939 116.9000015 127.3000031 126.5999985 79.59999847 162.3999939 155.8999939 131.6000061 116.6999969 160.3000031 126.0999985 18 | 121.4000015 126.5 131.5 117.4000015 150.1999969 131 117.5 135.3000031 148.8999939 124.1999969 126.4000015 214.8999939 140.1000061 138.5 118.3000031 123.9000015 140.1999969 120.3000031 116.0999985 183.3000031 254.6000061 101 197.3000031 121.3000031 130.8999939 141.6000061 122.5999985 145.8000031 135.6999969 117.4000015 127.1999969 126.4000015 79.09999847 160.8999939 151.8000031 122.0999985 117.0999985 168.6000061 121.9000015 19 | 123.1999969 131.8000031 131 118 150.5 134 115.1999969 135.1999969 146.8999939 124.5999985 127.0999985 215.3000031 143.8000031 141.1999969 117.6999969 127 142.1000061 122 116.3000031 177.6999969 247.8000031 102.6999969 187.8000031 123.6999969 133.5 141.6000061 124 149.3000031 138.3000031 114.6999969 130.3999939 129.6999969 74.80000305 161.6000061 148.8999939 122.3000031 117.5999985 158.1000061 120.1999969 20 | 119.5999985 128.6999969 133.8000031 116.4000015 152.6000061 131.6999969 114.0999985 133 148.5 132.8999939 132 209.6999969 144 138.8999939 120.8000031 125.3000031 140.5 121.0999985 117 171.8999939 245.3999939 103 179.3000031 125.6999969 132.8000031 143.6999969 125.1999969 151.1999969 136.1000061 115.6999969 129.1000061 129 77.59999847 163.8000031 149.8999939 120.5 119.9000015 163.1000061 118.5999985 21 | 119.0999985 127.4000015 130.5 114.6999969 154.1000061 131.1999969 111.5 130.6999969 147.6999969 116.1999969 130.8999939 210.6000061 143.8999939 139.5 119.4000015 125.8000031 139.6999969 122.4000015 117.0999985 165.1000061 239.8000031 97.5 179 126.8000031 134 147 123.3000031 146.3000031 136 113 131.3999939 131.1999969 73.59999847 162.3000031 147.3999939 119.8000031 115.5999985 157.6999969 115.4000015 22 | 116.3000031 128 125.3000031 114.0999985 149.6000061 128.6000061 111.3000031 127.9000015 143 115.5999985 127.5999985 201.1000061 133.6999969 135.3999939 113.1999969 122.3000031 134.1000061 113.6999969 110.8000031 159.1999969 232.8999939 96.30000305 169.8000031 119.5999985 130 140 125.3000031 135.8000031 131.1000061 109.8000031 129 126.4000015 69 153.8000031 144.6999969 115.6999969 106.3000031 141.1999969 110.8000031 23 | 113 123.0999985 119.6999969 112.5 144 126.3000031 103.5999985 124 137.8000031 111.1999969 121.6999969 183.1999969 128.8999939 135.5 110.8000031 116.4000015 130 110.0999985 107.6999969 136.6000061 215.1000061 88.90000153 160.6000061 109.4000015 127.0999985 128.1000061 115.3000031 136.8999939 127 105.6999969 125.0999985 117.1999969 66.30000305 144.3000031 136.8000031 111.9000015 105.5999985 128.8999939 104.8000031 24 | 114.5 125.8000031 112.4000015 111 144.5 128.8000031 100.6999969 121.5999985 135.3000031 109.4000015 115.6999969 182.3999939 125 127.9000015 113 115.3000031 129.1999969 103.5999985 105.0999985 146.6999969 201.1000061 88 156.3000031 103.1999969 126.6999969 124.1999969 115.8000031 133.3999939 
125.4000015 104.4000015 128.6999969 115.9000015 66.5 144.5 134.6000061 109.0999985 107 125.6999969 102.8000031 25 | 116.3000031 126 109.9000015 108.5 142.3999939 129 96.69999695 118.1999969 137.6000061 104.0999985 109.4000015 179.8000031 121.1999969 119 104.3000031 113.1999969 128.8000031 97.80000305 103.0999985 142.6000061 195.8999939 88.19999695 154.3999939 99.80000305 126.3000031 119.9000015 113.9000015 136.3000031 126.5999985 97 129 113.6999969 64.40000153 131.1999969 135.8000031 112.0999985 105.4000015 124.8000031 99.69999695 26 | 114 122.3000031 102.4000015 109 141 129.3000031 95 109.5 134 101.0999985 105.1999969 171.1999969 116.5 125 108.8000031 110 128.6999969 91.69999695 101.3000031 147.6999969 195.1000061 82.30000305 150.5 92.30000305 124.5999985 113.0999985 110.5999985 124.4000015 126.5999985 95.80000305 130.6000061 105.8000031 67.69999695 128.3000031 133 107.5 106 110.4000015 97.5 27 | 112.0999985 121.5 94.59999847 104.8000031 137.1000061 124.0999985 84.5 107.5999985 134 100.1999969 103.1999969 173.1999969 110.9000015 125 94.09999847 109 127.4000015 87.09999847 92.90000153 141.8999939 180.3999939 77.69999695 146 87.09999847 122.4000015 103.5999985 107.5999985 138 124.4000015 91.90000153 125.3000031 96.5 55 128.6999969 129.5 109.0999985 102.5999985 114.3000031 90.09999847 28 | 105.5999985 118.3000031 88.80000305 100.5999985 131.6999969 117.0999985 78.40000153 104.5999985 132.5 94.40000153 96.5 171.6000061 103.5999985 122.4000015 92.30000305 108.3000031 122.8000031 86.19999695 93.80000305 137.8999939 172.8999939 74.40000153 139.3000031 84.09999847 118.5999985 97.5 107.0999985 120.8000031 122.4000015 87.40000153 124.6999969 94.5 57 120.9000015 122.5 104 100.3000031 111.4000015 82.40000153 29 | 108.5999985 113.0999985 87.40000153 91.5 127.1999969 113.8000031 90.09999847 94.09999847 128.3000031 95.40000153 94.30000305 182.5 101.5 117.5 90.69999695 101.8000031 119.0999985 84.69999695 89.90000153 137.3000031 152.3999939 70.80000305 133.6999969 77.09999847 115.5 88.40000153 101.3000031 101.4000015 118.5999985 88.30000305 121.8000031 85.59999847 53.40000153 124.3000031 118.9000015 104.0999985 94 96.90000153 77.80000305 30 | 107.9000015 116.8000031 90.19999695 86.69999695 118.8000031 109.5999985 85.40000153 96.09999847 127.1999969 97.09999847 91.80000305 170.3999939 107.1999969 116.0999985 86.19999695 105.5999985 119.9000015 82.90000153 92.40000153 115.5 144.8000031 69.90000153 132.6999969 85.19999695 113.1999969 87.80000305 102.5 103.5999985 121.5 91.80000305 120.5999985 79.40000153 53.5 120.9000015 109.0999985 100.0999985 95.5 109.0999985 68.69999695 31 | 109.0999985 126 88.30000305 83.5 120 109.1999969 85.09999847 94.80000305 128.1999969 95.19999695 90 167.6000061 108.5 114.5 83.80000305 103.9000015 122.3000031 86.59999847 90.59999847 110 143.6999969 71.40000153 128.8999939 74.30000305 112.3000031 86.30000305 96.19999695 100.0999985 112.8000031 93 121 77.19999695 55 126.5 108.1999969 97.90000153 96.19999695 110.8000031 67.5 32 | 108.5 113.8000031 88.59999847 79.09999847 123.8000031 109.1999969 86.69999695 94.59999847 126.8000031 92.5 89.90000153 167.6000061 106.1999969 108.5 81.59999847 105.4000015 121.5999985 86 91.09999847 108.0999985 148.8999939 69 129.6999969 83 108.9000015 86.19999695 94.69999695 94.09999847 115.1999969 91.59999847 120.8000031 81.30000305 56.20000076 117.1999969 105.4000015 111 91.19999695 108.4000015 63.40000153 33 | 107.0999985 108.8000031 89.09999847 76.59999847 126.0999985 107.8000031 93 85.69999695 128.1999969 93.40000153 89.09999847 170.1000061 105.3000031 
101.5999985 83.40000153 106 119.4000015 88.19999695 85.90000153 105.1999969 153.8000031 68.19999695 112.6999969 81 108.5999985 104.8000031 95.40000153 91.90000153 112.1999969 94.80000305 118.8000031 78.80000305 55.79999924 120.3000031 106.1999969 104.1999969 91.80000305 111.1999969 58.59999847 34 | 102.5999985 113 85.40000153 79.30000305 127.1999969 100.3000031 78.19999695 84.30000305 135.3999939 93 90.09999847 175.3000031 105.6999969 102.3000031 84.09999847 107.5 124 90.5 88.5 100.9000015 158.5 67 124.9000015 80.59999847 111.6999969 109.5 95.40000153 90.80000305 109.1999969 98.59999847 125.4000015 75.19999695 52 123.1999969 106.6999969 115.1999969 93.5 115 56.40000153 35 | 101.4000015 110.6999969 83.09999847 76 128.3000031 102.6999969 73.59999847 81.80000305 135.1000061 94 88.69999695 179 106.8000031 100 81.69999695 106.9000015 124.0999985 87.30000305 86.19999695 99 158 65.69999695 129.6999969 80.80000305 107.5999985 110.8000031 93.30000305 87.5 102.9000015 92.30000305 119.1999969 74.59999847 54 102.5 104.5999985 112.6999969 92.09999847 110.3000031 54.5 36 | 104.9000015 108.6999969 81.30000305 75.90000153 124.0999985 100.5999985 75 79.59999847 135.3000031 93.90000153 89.19999695 186.8000031 105.3000031 101.0999985 84.09999847 106.3000031 120.5999985 88.90000153 85.5 95.59999847 174.3999939 61.79999924 125.5999985 77.5 108.5999985 111.8000031 92.90000153 90 124.5 88.80000305 118.9000015 72.59999847 57 97.69999695 108 114.5 91.90000153 108.8000031 53.79999924 37 | 106.1999969 109.5 81.19999695 75.5 132.8000031 100.5 78.90000153 80.30000305 135.8999939 94 87.59999847 171.3000031 103.1999969 94.5 83.19999695 107 120.0999985 89.09999847 83.09999847 102.4000015 173.8000031 62.59999847 126 79.09999847 106.4000015 112.1999969 92.09999847 88.69999695 126.9000015 88.30000305 119.6999969 73.19999695 42.29999924 97 105.5999985 114.5999985 88.69999695 102.9000015 52.29999924 38 | 100.6999969 104.8000031 79.59999847 73.40000153 139.5 97.09999847 75.09999847 72.19999695 133.3000031 91.69999695 83.30000305 165.3000031 101 85.5 80.69999695 103.9000015 118 82.59999847 86.59999847 103.9000015 171.6999969 59.70000076 113.0999985 74.69999695 104 111.4000015 91.09999847 86.90000153 109.4000015 83.5 115.5999985 67.59999847 43.90000153 94.09999847 102.0999985 112.4000015 84.40000153 104.8000031 47.20000076 39 | 96.19999695 99.40000153 73 71.40000153 140.6999969 88.40000153 66.90000153 70 125.5 88.90000153 79.80000305 156.1999969 104.3000031 82.90000153 76 97.19999695 113.8000031 75.5 77.59999847 93.19999695 147.3000031 53.79999924 109 72.5 99.90000153 108.9000015 87.90000153 83.09999847 103.9000015 75.09999847 108.6999969 69.30000305 40.70000076 88.90000153 96.69999695 107.9000015 80.09999847 90.5 41.59999847 40 | -------------------------------------------------------------------------------- /sample_data/README.md: -------------------------------------------------------------------------------- 1 | This package contains the following data sets: 2 | This data is from Abadie, Diamond, and Hainmueller (2010). 
The raw data is in MATLAB format from https://web.stanford.edu/~jhain/synthpage.html
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
description_file = README.md
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import os

from setuptools import setup


here = os.path.abspath(os.path.dirname(__file__))

install_requires = [
    'pandas',
    'matplotlib',
    'numpy',
    'tqdm',
    'scikit-learn',
    'scipy',
    'toolz',
    'bayesian-optimization >= 1.1.0'
]
tests_require = [
    'pytest',
    'pytest-cov',
    'mock',
    'tox'
]
setup_requires = [
    'flake8',
    'isort'
]
extras_require = {
    'docs': [
        'ipython',
        'jupyter'
    ]
}

packages = ['synthdid']

_version = {}
_version_path = os.path.join(here, 'synthdid', '__version__.py')

with open(_version_path, 'r') as f:
    exec(f.read(), _version)

with open('README.md', 'r') as f:
    readme = f.read()

setup(
    name='pysynthdid',
    version=_version['__version__'],
    author='MasaAsami',
    author_email='m.asami.moj@gmail.com',
    url='https://github.com/MasaAsami/pysynthdid',
    description="Python version of Synthetic difference in differences",
    long_description=readme,
    long_description_content_type='text/markdown',
    packages=packages,
    include_package_data=True,
    install_requires=install_requires,
    tests_require=tests_require,
    setup_requires=setup_requires,
    extras_require=extras_require,
    license='Apache License 2.0',
    keywords='causal-inference',
    classifiers=[
        'Intended Audience :: Developers',
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: Unix',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Topic :: Scientific/Engineering',
    ],
    project_urls={
        'Source': 'https://github.com/MasaAsami/pysynthdid'
    },
    python_requires='>=3.6',
    test_suite='test'
)
--------------------------------------------------------------------------------
/synthdid/__init__.py:
--------------------------------------------------------------------------------
name = "synthdid"

__all__ = [
    'model',
    '__version__',
    'optimizer',
    'plot',
    'variance',
    'summary'
]

from synthdid.__version__ import __version__
--------------------------------------------------------------------------------
/synthdid/__version__.py:
--------------------------------------------------------------------------------
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = '0.0.1'
--------------------------------------------------------------------------------
/synthdid/model.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np

from synthdid.optimizer import Optimize
from synthdid.plot import Plot
from synthdid.variance import Variance
from synthdid.summary import Summary


class SynthDID(Optimize, Plot, Variance, Summary):
    """
    Synthetic Difference in Differences

    df : pandas.DataFrame
    pre_term : [start, end] of the pre-treatment period
    post_term : [start, end] of the post-treatment period
    treatment_unit : list of treated column names

    [example]
    df = fetch_CaliforniaSmoking()
    sdid = SynthDID(df, [1970, 1979], [1980, 1988], ["California"])
    sdid.fit()
    """

    def __init__(
        self, df, pre_term, post_term, treatment_unit: list, random_seed=0, **kwargs
    ):
        # initial values
        self.df = df
        self.pre_term = pre_term
        self.post_term = post_term
        self.random_seed = random_seed
        # can be changed later
        self.treatment = treatment_unit
        self.control = [col for col in df.columns if col not in self.treatment]
        self._divide_data()

        # estimated parameters
        self.hat_zeta = None
        self.base_zeta = None
        self.hat_omega_ADH = None
        self.hat_omega = None
        self.hat_lambda = None
        self.hat_omega_ElasticNet = None
        self.hat_omega_Lasso = None
        self.hat_omega_Ridge = None
        self.sdid_se = None
        self.sc_se = None
        self.did_se = None

    def _divide_data(self):

        self.Y_pre_c = self.df.loc[self.pre_term[0] : self.pre_term[1], self.control]
        self.Y_pre_t = self.df.loc[self.pre_term[0] : self.pre_term[1], self.treatment]

        self.Y_post_c = self.df.loc[self.post_term[0] : self.post_term[1], self.control]
        self.Y_post_t = self.df.loc[
            self.post_term[0] : self.post_term[1], self.treatment
        ]

        self.n_treat = len(self.treatment)
        self.n_post_term = len(self.Y_post_t)

    def fit(
        self,
        model="all",
        zeta_type="base",
        force_zeta=None,
        sparce_estimation=False,
        cv=5,
        cv_split_type="KFold",
        candidate_zata=None,
        n_candidate=20,
        sc_v_model="linear",
        additional_X=pd.DataFrame(),
        additional_y=pd.DataFrame(),
    ):

        self.base_zeta = self.est_zeta(self.Y_pre_c)

        if zeta_type == "base":
            self.zeta = self.base_zeta

        elif zeta_type == "grid_search":
            self.zeta = self.grid_search_zeta(
                cv=cv,
                n_candidate=n_candidate,
                candidate_zata=candidate_zata,
                split_type=cv_split_type,
            )[0]

        elif zeta_type == "bayesian_opt":
            self.zeta = self.bayes_opt_zeta(cv=cv, split_type=cv_split_type)[0]

        else:
            print(f"zeta_type='{zeta_type}' is not supported; falling back to 'base'.")
            self.zeta = self.base_zeta

        if force_zeta is not None:
            self.zeta = force_zeta

        self.hat_omega = self.est_omega(self.Y_pre_c, self.Y_pre_t, self.zeta)
        self.hat_omega_ADH = self.est_omega_ADH(
            self.Y_pre_c,
            self.Y_pre_t,
            additional_X=additional_X,
            additional_y=additional_y,
        )
        self.hat_lambda = self.est_lambda(self.Y_pre_c, self.Y_post_c)

        if sparce_estimation:
            self.hat_omega_ElasticNet = self.est_omega_ElasticNet(
                self.Y_pre_c, self.Y_pre_t
            )
            self.hat_omega_Lasso = self.est_omega_Lasso(self.Y_pre_c, self.Y_pre_t)
            self.hat_omega_Ridge = self.est_omega_Ridge(self.Y_pre_c, self.Y_pre_t)

    def did_potentical_outcome(self):
        """
        Return the potential-outcome trajectory implied by standard DiD
        """
        Y_pre_c = self.Y_pre_c.copy()
        Y_pre_t = self.Y_pre_t.copy()
        Y_post_c = self.Y_post_c.copy()
        Y_post_t = self.Y_post_t.copy()

        if type(Y_pre_t) != pd.DataFrame:
            Y_pre_t = pd.DataFrame(Y_pre_t)

        if type(Y_post_t) != pd.DataFrame:
            Y_post_t = pd.DataFrame(Y_post_t)

        Y_pre_t["did"] = Y_pre_t.mean(axis=1)

        base_trend = Y_post_c.mean(axis=1)

        Y_post_t["did"] = base_trend + (
            Y_pre_t.mean(axis=1).mean() - Y_pre_c.mean(axis=1).mean()
        )

        return pd.concat([Y_pre_t["did"], Y_post_t["did"]], axis=0)

    def sc_potentical_outcome(self):
        return pd.concat([self.Y_pre_c, self.Y_post_c]).dot(self.hat_omega_ADH)

    def sparceReg_potentical_outcome(self, model="ElasticNet"):
        Y_pre_c_intercept = self.Y_pre_c.copy()
        Y_post_c_intercept = self.Y_post_c.copy()
        Y_pre_c_intercept["intercept"] = 1
        Y_post_c_intercept["intercept"] = 1

        if model == "ElasticNet":
            s_omega = self.hat_omega_ElasticNet
        elif model == "Lasso":
            s_omega = self.hat_omega_Lasso
        elif model == "Ridge":
            s_omega = self.hat_omega_Ridge
        else:
            print(f"model={model} is not supported")
            return None
        return pd.concat([Y_pre_c_intercept, Y_post_c_intercept]).dot(s_omega)

    def sdid_trajectory(self):
        hat_omega = self.hat_omega[:-1]
        Y_c = pd.concat([self.Y_pre_c, self.Y_post_c])
        n_features = self.Y_pre_c.shape[1]
        start_w = np.repeat(1 / n_features, n_features)

        _intercept = (start_w - hat_omega) @ self.Y_pre_c.T @ self.hat_lambda

        return Y_c.dot(hat_omega) + _intercept
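    # Clarifying note (added): the SDID counterfactual below combines the
    # post-period synthetic-control level with a DiD-style level shift,
    #     y_hat_post = Y_post_c @ omega + (lambda' Y_pre_t - lambda' (Y_pre_c @ omega)),
    # i.e., the weighted control trajectory is shifted so that its
    # lambda-weighted pre-period level matches that of the treated unit.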
    def sdid_potentical_outcome(self):
        Y_pre_c = self.Y_pre_c.copy()
        Y_post_c = self.Y_post_c.copy()
        hat_omega = self.hat_omega[:-1]

        base_sc = Y_post_c @ hat_omega
        pre_treat_base = (self.Y_pre_t.T @ self.hat_lambda).values[0]
        pre_control_base = Y_pre_c @ hat_omega @ self.hat_lambda

        pre_outcome = Y_pre_c.dot(hat_omega)

        post_outcome = base_sc + pre_treat_base - pre_control_base

        return pd.concat([pre_outcome, post_outcome], axis=0)

    def sparce_sdid_potentical_outcome(self, model="ElasticNet"):
        Y_pre_c_intercept = self.Y_pre_c.copy()
        Y_post_c_intercept = self.Y_post_c.copy()
        Y_pre_c_intercept["intercept"] = 1
        Y_post_c_intercept["intercept"] = 1

        if model == "ElasticNet":
            s_omega = self.hat_omega_ElasticNet
        elif model == "Lasso":
            s_omega = self.hat_omega_Lasso
        elif model == "Ridge":
            s_omega = self.hat_omega_Ridge
        else:
            print(f"model={model} is not supported")
            return None

        base_sc = Y_post_c_intercept @ s_omega
        pre_treat_base = (self.Y_pre_t.T @ self.hat_lambda).values[0]
        pre_control_base = Y_pre_c_intercept @ s_omega @ self.hat_lambda

        post_outcome = base_sc + pre_treat_base - pre_control_base

        return pd.concat([Y_pre_c_intercept.dot(s_omega), post_outcome], axis=0)

    def target_y(self):
        return self.df.loc[self.pre_term[0] : self.post_term[1], self.treatment].mean(
            axis=1
        )

    def estimated_params(self, model="sdid"):
        Y_pre_c_intercept = self.Y_pre_c.copy()
        Y_post_c_intercept = self.Y_post_c.copy()
        Y_pre_c_intercept["intercept"] = 1
        Y_post_c_intercept["intercept"] = 1
        if model == "sdid":
            return (
                pd.DataFrame(
                    {
                        "features": Y_pre_c_intercept.columns,
                        "sdid_weight": np.round(self.hat_omega, 3),
                    }
                ),
                pd.DataFrame(
                    {
                        "time": Y_pre_c_intercept.index,
                        "sdid_weight": np.round(self.hat_lambda, 3),
                    }
                ),
            )
        elif model == "sc":
            return pd.DataFrame(
                {
                    "features": self.Y_pre_c.columns,
                    "sc_weight": np.round(self.hat_omega_ADH, 3),
                }
            )
        elif model == "ElasticNet":
            return pd.DataFrame(
                {
                    "features": Y_pre_c_intercept.columns,
                    "ElasticNet_weight": np.round(self.hat_omega_ElasticNet, 3),
                }
            )
        elif model == "Lasso":
            return pd.DataFrame(
                {
                    "features": Y_pre_c_intercept.columns,
                    "Lasso_weight": np.round(self.hat_omega_Lasso, 3),
                }
            )
        elif model == "Ridge":
            return pd.DataFrame(
                {
                    "features": Y_pre_c_intercept.columns,
                    "Ridge_weight": np.round(self.hat_omega_Ridge, 3),
                }
            )
        else:
            return None

    def hat_tau(self, model="sdid"):
        """
        Return the estimated average treatment effect on the treated (ATT)
        """
        result = pd.DataFrame({"actual_y": self.target_y()})
        post_actual_treat = result.loc[self.post_term[0] :, "actual_y"].mean()

        if model == "sdid":
            result["sdid"] = self.sdid_trajectory()

            pre_sdid = result["sdid"].head(len(self.hat_lambda)) @ self.hat_lambda
            post_sdid = result.loc[self.post_term[0] :, "sdid"].mean()

            pre_treat = (self.Y_pre_t.T @ self.hat_lambda).values[0]
            counterfactual_post_treat = pre_treat + (post_sdid - pre_sdid)

        elif model == "sc":
            result["sc"] = self.sc_potentical_outcome()
            post_sc = result.loc[self.post_term[0] :, "sc"].mean()
            counterfactual_post_treat = post_sc

        elif model == "did":
            Y_pre_t = self.Y_pre_t.copy()
            Y_post_t = self.Y_post_t.copy()
            if type(Y_pre_t) != pd.DataFrame:
                Y_pre_t = pd.DataFrame(Y_pre_t)

            if type(Y_post_t) != pd.DataFrame:
                Y_post_t = pd.DataFrame(Y_post_t)

            # for DiD, compare the pre/post difference of the treated unit
            # with that of the control average
            post_actual_treat = (
                Y_post_t.mean(axis=1).mean() - Y_pre_t.mean(axis=1).mean()
            )
            counterfactual_post_treat = (
                self.Y_post_c.mean(axis=1).mean() - self.Y_pre_c.mean(axis=1).mean()
            )

        return post_actual_treat - counterfactual_post_treat

    def cal_se(self, algo="placebo", replications=200):

        sdid_var, sc_var, did_var = self.estimate_variance(
            algo=algo, replications=replications
        )

        self.sdid_se = np.sqrt(sdid_var)
        self.sc_se = np.sqrt(sc_var)
        self.did_se = np.sqrt(did_var)


if __name__ == "__main__":
    from sample_data import fetch_CaliforniaSmoking

    df = fetch_CaliforniaSmoking()

    PRE_TERM = [1970, 1979]
    POST_TERM = [1980, 1988]
    TREATMENT = ["California"]

    sdid = SynthDID(df, PRE_TERM, POST_TERM, TREATMENT)
    sdid.fit()
    sdid.plot(model="sdid")
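    # Illustrative additions (not part of the original script): after fitting,
    # the ATT and placebo standard errors are available via the methods above.
    sdid.cal_se(algo="placebo", replications=200)
    print("ATT (sdid):", round(sdid.hat_tau(model="sdid"), 4))
    print("se (sdid):", round(sdid.sdid_se, 4))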
--------------------------------------------------------------------------------
/synthdid/optimizer.py:
--------------------------------------------------------------------------------
from tqdm import tqdm
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from scipy.optimize import fmin_slsqp
from toolz import partial
from sklearn.model_selection import KFold, TimeSeriesSplit, RepeatedKFold
from sklearn.linear_model import ElasticNetCV, LassoCV, RidgeCV
from bayes_opt import BayesianOptimization


class Optimize(object):
    ####
    # Synthetic Difference in Differences (SDID)
    ####
    def est_zeta(self, Y_pre_c) -> float:
        """
        # SDID
        Regularization parameter for the L2 penalty term:
        (n_treat * n_post_term) ** (1/4) times the standard deviation of the
        first differences of the pre-treatment control outcomes
        """
        return (self.n_treat * self.n_post_term) ** (1 / 4) * np.std(
            Y_pre_c.diff().dropna().values
        )

    def est_omega(self, Y_pre_c, Y_pre_t, zeta):
        """
        # SDID
        estimating omega (unit weights)
        """
        Y_pre_t = Y_pre_t.copy()
        n_features = Y_pre_c.shape[1]
        nrow = Y_pre_c.shape[0]

        _w = np.repeat(1 / n_features, n_features)
        _w0 = 1

        start_w = np.append(_w, _w0)

        if type(Y_pre_t) == pd.core.frame.DataFrame:
            Y_pre_t = Y_pre_t.mean(axis=1)

        # unit weights are required to be non-negative; the intercept
        # (last element) is only loosely bounded
        max_bnd = abs(Y_pre_t.mean()) * 2
        w_bnds = tuple(
            (0, 1) if i < n_features else (max_bnd * -1, max_bnd)
            for i in range(n_features + 1)
        )

        caled_w = fmin_slsqp(
            partial(self.l2_loss, X=Y_pre_c, y=Y_pre_t, zeta=zeta, nrow=nrow),
            start_w,
            f_eqcons=lambda x: np.sum(x[:n_features]) - 1,
            bounds=w_bnds,
            disp=False,
        )

        return caled_w

    def est_lambda(self, Y_pre_c, Y_post_c):
        """
        # SDID
        estimating lambda (time weights)
        """
        Y_pre_c_T = Y_pre_c.T
        Y_post_c_T = Y_post_c.T

        n_pre_term = Y_pre_c_T.shape[1]

        _lambda = np.repeat(1 / n_pre_term, n_pre_term)
        _lambda0 = 1

        start_lambda = np.append(_lambda, _lambda0)

        if type(Y_post_c_T) == pd.core.frame.DataFrame:
            Y_post_c_T = Y_post_c_T.mean(axis=1)

        max_bnd = abs(Y_post_c_T.mean()) * 2
        lambda_bnds = tuple(
            (0, 1) if i < n_pre_term else (max_bnd * -1, max_bnd)
            for i in range(n_pre_term + 1)
        )

        caled_lambda = fmin_slsqp(
            partial(self.l2_loss, X=Y_pre_c_T, y=Y_post_c_T, zeta=0, nrow=0),
            start_lambda,
            f_eqcons=lambda x: np.sum(x[:n_pre_term]) - 1,
            bounds=lambda_bnds,
            disp=False,
        )

        return caled_lambda[:n_pre_term]

    def l2_loss(self, W, X, y, zeta, nrow) -> float:
        """
        Least-squares loss with an L2 penalty on the weights
        (the intercept, the last element of W, is not penalized)
        """
        if type(y) == pd.core.frame.DataFrame:
            y = y.mean(axis=1)
        _X = X.copy()
        _X["intercept"] = 1
        return np.sum((y - _X.dot(W)) ** 2) + nrow * zeta ** 2 * np.sum(W[:-1] ** 2)
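    # For reference: with W = (w_1, ..., w_J, w_0), l2_loss evaluates
    #     sum_t ( y_t - (X_t . w + w_0) )^2 + nrow * zeta^2 * ||w||^2,
    # so passing zeta=0 (as in est_lambda) reduces it to plain least squares.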
    ####
    # Synthetic Control Method (SC)
    ####
    def rmse_loss(self, W, X, y, intersept=True) -> float:
        if type(y) == pd.core.frame.DataFrame:
            y = y.mean(axis=1)
        _X = X.copy()
        if intersept:
            _X["intercept"] = 1
        # root-mean-squared error of the weighted fit
        return np.sqrt(np.mean((y - _X.dot(W)) ** 2))

    def rmse_loss_with_V(self, W, V, X, y) -> float:
        if type(y) == pd.core.frame.DataFrame:
            y = y.mean(axis=1)
        _rss = (y - X.dot(W)) ** 2

        _n = len(y)
        _importance = np.zeros((_n, _n))

        np.fill_diagonal(_importance, V)

        return np.sum(_importance @ _rss)

    def _v_loss(self, V, X, y, return_loss=True):
        Y_pre_t = self.Y_pre_t.copy()

        n_features = self.Y_pre_c.shape[1]
        _w = np.repeat(1 / n_features, n_features)

        if type(Y_pre_t) == pd.core.frame.DataFrame:
            Y_pre_t = Y_pre_t.mean(axis=1)

        w_bnds = tuple((0, 1) for i in range(n_features))
        _caled_w = fmin_slsqp(
            partial(self.rmse_loss_with_V, V=V, X=X, y=y),
            _w,
            f_eqcons=lambda x: np.sum(x) - 1,
            bounds=w_bnds,
            disp=False,
        )
        if return_loss:
            return self.rmse_loss(_caled_w, self.Y_pre_c, Y_pre_t, intersept=False)
        else:
            return _caled_w

    def estimate_v(self, additional_X, additional_y):
        _len = len(additional_X)
        _v = np.repeat(1 / _len, _len)

        caled_v = fmin_slsqp(
            partial(self._v_loss, X=additional_X, y=additional_y),
            _v,
            f_eqcons=lambda x: np.sum(x) - 1,
            bounds=tuple((0, 1) for i in range(_len)),
            disp=False,
        )
        return caled_v

    def est_omega_ADH(
        self, Y_pre_c, Y_pre_t, additional_X=pd.DataFrame(), additional_y=pd.DataFrame()
    ):
        """
        # SC
        estimating omega for the synthetic control method (not for synthetic diff.-in-diff.)
        """
        Y_pre_t = Y_pre_t.copy()

        n_features = Y_pre_c.shape[1]
        nrow = Y_pre_c.shape[0]

        _w = np.repeat(1 / n_features, n_features)

        if type(Y_pre_t) == pd.core.frame.DataFrame:
            Y_pre_t = Y_pre_t.mean(axis=1)

        # weights are required to be non-negative
        w_bnds = tuple((0, 1) for i in range(n_features))

        if len(additional_X) == 0:
            caled_w = fmin_slsqp(
                partial(self.rmse_loss, X=Y_pre_c, y=Y_pre_t, intersept=False),
                _w,
                f_eqcons=lambda x: np.sum(x) - 1,
                bounds=w_bnds,
                disp=False,
            )

            return caled_w
        else:
            assert additional_X.shape[1] == Y_pre_c.shape[1]
            if type(additional_y) == pd.core.frame.DataFrame:
                additional_y = additional_y.mean(axis=1)

            # normalize the additional covariates
            temp_df = pd.concat([additional_X, additional_y], axis=1)
            ss = StandardScaler()
            ss_df = pd.DataFrame(
                ss.fit_transform(temp_df), columns=temp_df.columns, index=temp_df.index
            )

            ss_X = ss_df.iloc[:, :-1]
            ss_y = ss_df.iloc[:, -1]

            add_X = pd.concat([Y_pre_c, ss_X])
            add_y = pd.concat([Y_pre_t, ss_y])

            self.caled_v = self.estimate_v(additional_X=add_X, additional_y=add_y)

            return self._v_loss(self.caled_v, X=add_X, y=add_y, return_loss=False)

    #####
    # cv search for zeta
    ####

    def _zeta_given_cv_loss_inverse(self, zeta, cv=5, split_type="KFold"):
        return -1 * self._zeta_given_cv_loss(zeta, cv, split_type)[0]

    def _zeta_given_cv_loss(self, zeta, cv=5, split_type="KFold"):
        nrow = self.Y_pre_c.shape[0]
        if split_type == "KFold":
            kf = KFold(n_splits=cv, shuffle=True, random_state=self.random_seed)
        elif split_type == "TimeSeriesSplit":
            kf = TimeSeriesSplit(n_splits=cv)
        elif split_type == "RepeatedKFold":
            _cv = max(2, int(cv / 2))
            kf = RepeatedKFold(
                n_splits=_cv, n_repeats=_cv, random_state=self.random_seed
            )

        loss_result = []
        nf_result = []
        for train_index, test_index in kf.split(self.Y_pre_c, self.Y_pre_t):
            train_w = self.est_omega(
                self.Y_pre_c.iloc[train_index], self.Y_pre_t.iloc[train_index], zeta
            )

            # count features that receive a non-negligible weight
            nf_result.append(np.sum(np.round(np.abs(train_w), 3) > 0) - 1)

            loss_result.append(
                self.rmse_loss(
                    train_w,
                    self.Y_pre_c.iloc[test_index],
                    self.Y_pre_t.iloc[test_index],
                )
            )
        return np.mean(loss_result), np.mean(nf_result)
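    # Both search methods below score candidate zeta values with the CV loss
    # above: omega is re-estimated on each training split and then evaluated
    # by rmse_loss on the held-out split of the pre-treatment panel.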
    #####
    # CV search for zeta
    ####

    def _zeta_given_cv_loss_inverse(self, zeta, cv=5, split_type="KFold"):
        # BayesianOptimization maximizes, so return the negative CV loss.
        return -1 * self._zeta_given_cv_loss(zeta, cv, split_type)[0]

    def _zeta_given_cv_loss(self, zeta, cv=5, split_type="KFold"):
        if split_type == "KFold":
            # shuffle=True is required for random_state to take effect.
            kf = KFold(n_splits=cv, shuffle=True, random_state=self.random_seed)
        elif split_type == "TimeSeriesSplit":
            kf = TimeSeriesSplit(n_splits=cv)
        elif split_type == "RepeatedKFold":
            _cv = max(2, int(cv / 2))
            kf = RepeatedKFold(
                n_splits=_cv, n_repeats=_cv, random_state=self.random_seed
            )
        else:
            raise ValueError(f"unknown split_type: {split_type}")

        loss_result = []
        nf_result = []
        for train_index, test_index in kf.split(self.Y_pre_c, self.Y_pre_t):
            train_w = self.est_omega(
                self.Y_pre_c.iloc[train_index], self.Y_pre_t.iloc[train_index], zeta
            )

            # Number of non-zero weights, minus one for the intercept term.
            nf_result.append(np.sum(np.round(np.abs(train_w), 3) > 0) - 1)

            loss_result.append(
                self.rmse_loss(
                    train_w,
                    self.Y_pre_c.iloc[test_index],
                    self.Y_pre_t.iloc[test_index],
                )
            )
        return np.mean(loss_result), np.mean(nf_result)

    def grid_search_zeta(
        self, cv=5, n_candidate=20, candidate_zata=None, split_type="KFold"
    ):
        """
        Search for zeta by grid search instead of using the theoretical value.
        """
        # Build the candidate list per call; a mutable default argument
        # would be shared (and appended to) across calls.
        if candidate_zata is None:
            candidate_zata = []

        if len(candidate_zata) == 0:
            for _z in np.linspace(0.1, self.base_zeta * 2, n_candidate):
                candidate_zata.append(_z)

            candidate_zata.append(self.base_zeta)
            candidate_zata.append(0)

        candidate_zata = sorted(candidate_zata)

        result_loss_dict = {}
        result_nf_dict = {}

        print("CV search over zeta candidates")
        for _zeta in tqdm(candidate_zata):
            result_loss_dict[_zeta], result_nf_dict[_zeta] = self._zeta_given_cv_loss(
                _zeta, cv=cv, split_type=split_type
            )

        loss_sorted = sorted(result_loss_dict.items(), key=lambda x: x[1])

        # Return the (zeta, loss) pair with the smallest CV loss.
        return loss_sorted[0]

    def bayes_opt_zeta(
        self,
        cv=5,
        init_points=5,
        n_iter=5,
        zeta_max=None,
        zeta_min=None,
        split_type="KFold",
    ):
        """
        Search for zeta by Bayesian optimization instead of using the
        theoretical value.
        """
        if zeta_max is None:
            zeta_max = self.base_zeta * 1.02
        if zeta_min is None:
            zeta_min = self.base_zeta * 0.98

        # Wider bounds used in the second stage of the search.
        zeta_max2 = self.base_zeta * 2
        zeta_min2 = 0.01

        pbounds = {"zeta": (zeta_min, zeta_max)}

        optimizer = BayesianOptimization(
            f=partial(self._zeta_given_cv_loss_inverse, cv=cv, split_type=split_type),
            pbounds=pbounds,
            random_state=self.random_seed,
        )

        # Warm up in a narrow band around the theoretical zeta, then widen
        # the bounds and search again.
        optimizer.maximize(
            init_points=2,
            n_iter=2,
        )

        optimizer.set_bounds(new_bounds={"zeta": (zeta_min2, zeta_max2)})

        optimizer.maximize(
            init_points=init_points,
            n_iter=n_iter,
        )

        return (optimizer.max["params"]["zeta"], optimizer.max["target"] * -1)

    #####
    # The following is for sparse (regularized) estimation
    ####
    def est_omega_ElasticNet(self, Y_pre_c, Y_pre_t):
        Y_pre_t = Y_pre_t.copy()

        if isinstance(Y_pre_t, pd.DataFrame):
            Y_pre_t = Y_pre_t.mean(axis=1)

        regr = ElasticNetCV(cv=5, random_state=0)
        regr.fit(Y_pre_c, Y_pre_t)

        self.elastic_net_alpha = regr.alpha_

        caled_w = regr.coef_

        return np.append(caled_w, regr.intercept_)

    def est_omega_Lasso(self, Y_pre_c, Y_pre_t):
        Y_pre_t = Y_pre_t.copy()

        if isinstance(Y_pre_t, pd.DataFrame):
            Y_pre_t = Y_pre_t.mean(axis=1)

        regr = LassoCV(cv=5, random_state=0)
        regr.fit(Y_pre_c, Y_pre_t)

        self.lasso_alpha = regr.alpha_

        caled_w = regr.coef_

        return np.append(caled_w, regr.intercept_)

    def est_omega_Ridge(self, Y_pre_c, Y_pre_t):
        Y_pre_t = Y_pre_t.copy()

        if isinstance(Y_pre_t, pd.DataFrame):
            Y_pre_t = Y_pre_t.mean(axis=1)

        regr = RidgeCV(cv=5)
        regr.fit(Y_pre_c, Y_pre_t)

        self.ridge_alpha = regr.alpha_

        caled_w = regr.coef_

        return np.append(caled_w, regr.intercept_)
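
if __name__ == "__main__":
    # Minimal, self-contained sketch of the constrained solve used throughout
    # this module: fit simplex weights w (w >= 0, sum(w) = 1) so that a convex
    # combination of "control" columns approximates a "treated" series. The
    # data here are synthetic and purely illustrative.
    rng = np.random.default_rng(0)
    X = pd.DataFrame(rng.normal(size=(20, 4)), columns=list("ABCD"))
    true_w = np.array([0.5, 0.3, 0.2, 0.0])
    y = X.dot(true_w) + rng.normal(scale=0.01, size=20)

    def loss(w):
        return np.sum((y - X.dot(w)) ** 2)

    w0 = np.repeat(1 / 4, 4)
    w_hat = fmin_slsqp(
        loss,
        w0,
        f_eqcons=lambda w: np.sum(w) - 1,
        bounds=[(0, 1)] * 4,
        disp=False,
    )
    print(np.round(w_hat, 2))  # should be close to [0.5, 0.3, 0.2, 0.0]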
--------------------------------------------------------------------------------
/synthdid/plot.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


class Plot(object):
    def plot(self, model="sdid", figsize=(10, 7)):

        result = pd.DataFrame({"actual_y": self.target_y()})
        post_actual_treat = result.loc[self.post_term[0] :, "actual_y"].mean()
        post_point = np.mean(self.Y_post_c.index)

        if model == "sdid":
            result["sdid"] = self.sdid_trajectory()
            time_result = pd.DataFrame(
                {
                    "time": self.Y_pre_c.index,
                    "sdid_weight": np.round(self.hat_lambda, 3),
                }
            )

            pre_point = self.Y_pre_c.index @ self.hat_lambda

            pre_sdid = result["sdid"].head(len(self.hat_lambda)) @ self.hat_lambda
            post_sdid = result.loc[self.post_term[0] :, "sdid"].mean()

            pre_treat = (self.Y_pre_t.T @ self.hat_lambda).values[0]
            counterfactual_post_treat = pre_treat + (post_sdid - pre_sdid)

            fig, ax = plt.subplots(figsize=figsize)

            result["actual_y"].plot(
                ax=ax, color="blue", linewidth=1, label="treatment group", alpha=0.6
            )
            result["sdid"].plot(
                ax=ax, color="red", linewidth=1, label="synthetic control", alpha=0.6
            )
            ax.plot(
                [pre_point, post_point],
                [pre_sdid, post_sdid],
                label="",
                marker="o",
                color="red",
            )
            ax.plot(
                [pre_point, post_point],
                [pre_treat, post_actual_treat],
                label="",
                marker="o",
                color="blue",
            )
            ax.plot(
                [pre_point, post_point],
                [pre_treat, counterfactual_post_treat],
                label="",
                marker="o",
                color="blue",
                linewidth=1,
                linestyle="dashed",
                alpha=0.3,
            )

            ax.axvline(
                x=(self.pre_term[1] + self.post_term[0]) * 0.5,
                linewidth=1,
                linestyle="dashed",
                color="black",
                alpha=0.3,
            )

            ax2 = ax.twinx()
            ax2.bar(
                time_result["time"],
                time_result["sdid_weight"],
                color="#ff7f00",
                label="time weight",
                width=1.0,
                alpha=0.6,
            )
            ax2.set_ylim(0, 3)
            ax2.axis("off")
            ax.set_title(
                f"Synthetic Difference in Differences : tau {round(post_actual_treat - counterfactual_post_treat, 4)}"
            )
            ax.legend()
            plt.show()

        elif model == "sc":
            result["sc"] = self.sc_potentical_outcome()

            pre_sc = result.loc[: self.pre_term[1], "sc"].mean()
            post_sc = result.loc[self.post_term[0] :, "sc"].mean()

            pre_treat = self.Y_pre_t.mean()
            counterfactual_post_treat = post_sc

            fig, ax = plt.subplots(figsize=figsize)

            result["actual_y"].plot(
                ax=ax, color="blue", linewidth=1, label="treatment group", alpha=0.6
            )
            result["sc"].plot(
                ax=ax, color="red", linewidth=1, label="synthetic control", alpha=0.6
            )

            ax.annotate(
                "",
                xy=(post_point, post_actual_treat),
                xytext=(post_point, counterfactual_post_treat),
                arrowprops=dict(arrowstyle="-|>", color="black"),
            )

            ax.axvline(
                x=(self.pre_term[1] + self.post_term[0]) * 0.5,
                linewidth=1,
                linestyle="dashed",
                color="black",
                alpha=0.3,
            )
            ax.set_title(
                f"Synthetic Control Method : tau {round(post_actual_treat - counterfactual_post_treat, 4)}"
            )
            ax.legend()
            plt.show()

        elif model == "did":
            Y_pre_t = self.Y_pre_t.copy()
            Y_post_t = self.Y_post_t.copy()
            if not isinstance(Y_pre_t, pd.DataFrame):
                Y_pre_t = pd.DataFrame(Y_pre_t)

            if not isinstance(Y_post_t, pd.DataFrame):
                Y_post_t = pd.DataFrame(Y_post_t)

            result["did"] = self.df[self.control].mean(axis=1)
            pre_point = np.mean(self.Y_pre_c.index)

            pre_did = result.loc[: self.pre_term[1], "did"].mean()
            post_did = result.loc[self.post_term[0] :, "did"].mean()

            pre_treat = Y_pre_t.mean(axis=1).mean()
            counterfactual_post_treat = pre_treat + (post_did - pre_did)

            fig, ax = plt.subplots(figsize=figsize)

            result["actual_y"].plot(
                ax=ax, color="blue", linewidth=1, label="treatment group", alpha=0.6
            )
            result["did"].plot(
                ax=ax, color="red", linewidth=1, label="control", alpha=0.6
            )

            ax.plot(
                [pre_point, post_point],
                [pre_did, post_did],
                label="",
                marker="o",
                color="red",
            )
            ax.plot(
                [pre_point, post_point],
                [pre_treat, post_actual_treat],
                label="",
                marker="o",
                color="blue",
            )
            ax.plot(
                [pre_point, post_point],
                [pre_treat, counterfactual_post_treat],
                label="",
                marker="o",
                color="blue",
                linewidth=1,
                linestyle="dashed",
                alpha=0.3,
            )

            ax.axvline(
                x=(self.pre_term[1] + self.post_term[0]) * 0.5,
                linewidth=1,
                linestyle="dashed",
                color="black",
                alpha=0.3,
            )
            ax.set_title(
                f"Difference in Differences : tau {round(post_actual_treat - counterfactual_post_treat, 4)}"
            )
            ax.legend()
            plt.show()

        else:
            print(f"Sorry: {model} is not yet available.")
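    # A sketch of how the sdid and did panels compute the reported effect:
    # the counterfactual post-treatment level is the treated unit's
    # pre-treatment level shifted by the control trend,
    #
    #   counterfactual = pre_treat + (post_control_fit - pre_control_fit)
    #   tau            = post_actual_treat - counterfactual
    #
    # which is the difference-in-differences contrast behind each title. The
    # sc panel instead uses the synthetic control's post-period mean directly
    # as the counterfactual.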
    def comparison_plot(self, model="all", figsize=(10, 7)):
        result = pd.DataFrame({"actual_y": self.target_y()})

        result["did"] = self.did_potentical_outcome()
        result["sc"] = self.sc_potentical_outcome()
        result["sdid"] = self.sdid_potentical_outcome()

        result = result.loc[self.post_term[0] : self.post_term[1]]

        fig, ax = plt.subplots(figsize=figsize)

        result["actual_y"].plot(ax=ax, color="black", linewidth=1, label="actual_y")
        result["sdid"].plot(ax=ax, label="Synthetic Difference in Differences")
        result["sc"].plot(
            ax=ax, linewidth=1, linestyle="dashed", label="Synthetic Control"
        )
        result["did"].plot(
            ax=ax, linewidth=1, linestyle="dashed", label="Difference in Differences"
        )

        plt.legend(
            bbox_to_anchor=(1.05, 1), loc="upper left", borderaxespad=0, fontsize=18
        )
        plt.show()
--------------------------------------------------------------------------------
/synthdid/sample_data.py:
--------------------------------------------------------------------------------
import pandas as pd


def fetch_CaliforniaSmoking() -> pd.DataFrame:
    """
    This data is from https://web.stanford.edu/~jhain/synthpage.html
    [Return]
    pd.DataFrame
    """
    # NOTE: the path is relative, so this assumes the working directory is a
    # sibling of sample_data/ (e.g. the notebook/ directory).
    _raw = pd.read_csv("../sample_data/MLAB_data.txt", sep="\t", header=None)

    _raw.columns = [
        "Alabama",
        "Arkansas",
        "Colorado",
        "Connecticut",
        "Delaware",
        "Georgia",
        "Idaho",
        "Illinois",
        "Indiana",
        "Iowa",
        "Kansas",
        "Kentucky",
        "Louisiana",
        "Maine",
        "Minnesota",
        "Mississippi",
        "Missouri",
        "Montana",
        "Nebraska",
        "Nevada",
        "New Hampshire",
        "New Mexico",
        "North Carolina",
        "North Dakota",
        "Ohio",
        "Oklahoma",
        "Pennsylvania",
        "Rhode Island",
        "South Carolina",
        "South Dakota",
        "Tennessee",
        "Texas",
        "Utah",
        "Vermont",
        "Virginia",
        "West Virginia",
        "Wisconsin",
        "Wyoming",
        "California",
    ]

    _raw.index = list(range(1962, 2001))

    return _raw.loc[1970:]
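
if __name__ == "__main__":
    # Quick sanity check of the loader; the expected shape follows from the
    # code above (years 1970-2000 as rows, 39 state columns), and it assumes
    # the relative path resolves as noted in the function.
    df = fetch_CaliforniaSmoking()
    print(df.shape)  # (31, 39)
    print(df.columns[-1])  # "California", the treated unit in the examples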
"Montana", 31 | "Nebraska", 32 | "Nevada", 33 | "New Hampshire", 34 | "New Mexico", 35 | "North Carolina", 36 | "North Dakota", 37 | "Ohio", 38 | "Oklahoma", 39 | "Pennsylvania", 40 | "Rhode Island", 41 | "South Carolina", 42 | "South Dakota", 43 | "Tennessee", 44 | "Texas", 45 | "Utah", 46 | "Vermont", 47 | "Virginia", 48 | "West Virginia", 49 | "Wisconsin", 50 | "Wyoming", 51 | "California", 52 | ] 53 | 54 | _raw.index = [i for i in range(1962, 2001)] 55 | 56 | return _raw.loc[1970:] 57 | 58 | -------------------------------------------------------------------------------- /synthdid/summary.py: -------------------------------------------------------------------------------- 1 | class Summary(object): 2 | def summary(self, model="sdid"): 3 | if model == "sdid": 4 | att = self.hat_tau(model="sdid") 5 | print("------------------------------------------------------------") 6 | print("Syntetic Difference in Differences") 7 | print("") 8 | if self.sdid_se != None: 9 | print(f"point estimate: {att:.3f} ({self.sdid_se:.3f})") 10 | print( 11 | f"95% CI ({att - 1.96*self.sdid_se :.3f}, {att + 1.96*self.sdid_se:.3f})" 12 | ) 13 | else: 14 | print(f"point estimate: {att:.3f}") 15 | elif model == "sc": 16 | att = self.hat_tau(model="sc") 17 | print("------------------------------------------------------------") 18 | print("Syntetic Control Method") 19 | print("") 20 | if self.sdid_se != None: 21 | print(f"point estimate: {att:.3f} ({self.sc_se:.3f})") 22 | print( 23 | f"95% CI ({att - 1.96*self.sc_se :.3f}, {att + 1.96*self.sc_se:.3f})" 24 | ) 25 | else: 26 | print(f"point estimate: {att:.3f}") 27 | 28 | elif model == "did": 29 | att = self.hat_tau(model="did") 30 | print("------------------------------------------------------------") 31 | print("Difference in Differences") 32 | print("") 33 | if self.sdid_se != None: 34 | print(f"point estimate: {att:.3f} ({self.did_se:.3f})") 35 | print( 36 | f"95% CI ({att - 1.96*self.did_se :.3f}, {att + 1.96*self.did_se:.3f})" 37 | ) 38 | else: 39 | print(f"point estimate: {att:.3f}") 40 | -------------------------------------------------------------------------------- /synthdid/variance.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from tqdm import tqdm 4 | 5 | 6 | class Variance(object): 7 | def estimate_variance(self, algo="placebo", replications=200): 8 | """ 9 | # algo 10 | - placebo 11 | ## The following algorithms are omitted because they are not practical. 
        if algo == "placebo":
            Y_pre_c = self.Y_pre_c.copy()
            Y_post_c = self.Y_post_c.copy()
            assert self.n_treat < Y_pre_c.shape[1]
            control_names = Y_pre_c.columns

            result_tau_sdid = []
            result_tau_sc = []
            result_tau_did = []
            for i in tqdm(range(replications)):
                # setup
                np.random.seed(seed=self.random_seed + i)
                placebo_t = np.random.choice(control_names, self.n_treat, replace=False)
                placebo_c = [col for col in control_names if col not in placebo_t]
                pla_Y_pre_t = Y_pre_c[placebo_t]
                pla_Y_post_t = Y_post_c[placebo_t]
                pla_Y_pre_c = Y_pre_c[placebo_c]
                pla_Y_post_c = Y_post_c[placebo_c]

                pla_result = pd.DataFrame(
                    {
                        "pla_actual_y": pd.concat([pla_Y_pre_t, pla_Y_post_t]).mean(
                            axis=1
                        )
                    }
                )
                post_placebo_treat = pla_result.loc[
                    self.post_term[0] :, "pla_actual_y"
                ].mean()

                # estimation
                ## sdid
                pla_zeta = self.est_zeta(pla_Y_pre_c)

                pla_hat_omega = self.est_omega(pla_Y_pre_c, pla_Y_pre_t, pla_zeta)
                pla_hat_lambda = self.est_lambda(pla_Y_pre_c, pla_Y_post_c)
                ## sc
                pla_hat_omega_ADH = self.est_omega_ADH(pla_Y_pre_c, pla_Y_pre_t)

                # prediction
                ## sdid
                pla_hat_omega = pla_hat_omega[:-1]
                pla_Y_c = pd.concat([pla_Y_pre_c, pla_Y_post_c])
                n_features = pla_Y_pre_c.shape[1]
                start_w = np.repeat(1 / n_features, n_features)

                _intercept = (start_w - pla_hat_omega) @ pla_Y_pre_c.T @ pla_hat_lambda

                pla_result["sdid"] = pla_Y_c.dot(pla_hat_omega) + _intercept

                ## sc
                pla_result["sc"] = pla_Y_c.dot(pla_hat_omega_ADH)

                # calculate tau
                ## sdid
                pre_sdid = pla_result["sdid"].head(len(pla_hat_lambda)) @ pla_hat_lambda
                post_sdid = pla_result.loc[self.post_term[0] :, "sdid"].mean()

                pre_treat = (pla_Y_pre_t.T @ pla_hat_lambda).values[0]
                sdid_counterfactual_post_treat = pre_treat + (post_sdid - pre_sdid)

                result_tau_sdid.append(
                    post_placebo_treat - sdid_counterfactual_post_treat
                )

                ## sc
                sc_counterfactual_post_treat = pla_result.loc[
                    self.post_term[0] :, "sc"
                ].mean()
                result_tau_sc.append(post_placebo_treat - sc_counterfactual_post_treat)

                ## did
                did_post_actual_treat = (
                    post_placebo_treat
                    - pla_result.loc[: self.pre_term[1], "pla_actual_y"].mean()
                )
                did_counterfactual_post_treat = (
                    pla_Y_post_c.mean(axis=1).mean() - pla_Y_pre_c.mean(axis=1).mean()
                )
                result_tau_did.append(
                    did_post_actual_treat - did_counterfactual_post_treat
                )

        return (
            np.var(result_tau_sdid),
            np.var(result_tau_sc),
            np.var(result_tau_did),
        )
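
# A sketch of how these variances are consumed elsewhere: model.py is assumed
# to store se = np.sqrt(var) for each estimator, and Summary.summary() then
# reports att +/- 1.96 * se as the 95% confidence interval.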
--------------------------------------------------------------------------------
/test/test_data/lambda_CalifolinaSmoking.csv:
--------------------------------------------------------------------------------
year,lambda_sdid
1970,0
1971,0
1972,0
1973,0
1974,0
1975,0
1976,0
1977,0
1978,0
1979,0
1980,0
1981,0
1982,0
1983,0
1984,0
1985,0
1986,0.366
1987,0.206
1988,0.427
--------------------------------------------------------------------------------
/test/test_data/omega_CalifolinaSmoking.csv:
--------------------------------------------------------------------------------
state,omega_ADH,omega_sdid
Alabama,0,0
Arkansas,0,0.03
Colorado,0.013,0.058
Connecticut,0.104,0.078
Delaware,0.004,0.07
Georgia,0,0.002
Idaho,0,0.031
Illinois,0,0.053
Indiana,0,0.01
Iowa,0,0.026
Kansas,0,0.022
Kentucky,0,0
Louisiana,0,0
Maine,0,0.028
Minnesota,0,0.039
Mississippi,0,0
Missouri,0,0.008
Montana,0.232,0.045
Nebraska,0,0.048
Nevada,0.204,0.124
New Hampshire,0.045,0.105
New Mexico,0,0.041
North Carolina,0,0.033
North Dakota,0,0
Ohio,0,0.031
Oklahoma,0,0
Pennsylvania,0,0.015
Rhode Island,0,0.001
South Carolina,0,0
South Dakota,0,0.004
Tennessee,0,0
Texas,0,0.01
Utah,0.396,0.042
Vermont,0,0
Virginia,0,0
West Virginia,0,0.034
Wisconsin,0,0.037
Wyoming,0,0.001
--------------------------------------------------------------------------------
/test/test_model.py:
--------------------------------------------------------------------------------
import sys
import os
from scipy.stats import spearmanr

sys.path.append(os.path.abspath("../"))

import pandas as pd
import numpy as np
import pytest

from synthdid.model import SynthDID
from synthdid.sample_data import fetch_CaliforniaSmoking


class TestModelSynth(object):
    def test_params_with_originalpaper(self):
        """
        Compare estimated weights with those reported in the original paper
        (Arkhangelsky, Dmitry, et al. "Synthetic Difference in Differences."
        NBER Working Paper No. w25532, 2019. https://arxiv.org/abs/1812.09970)
        """
        test_df = fetch_CaliforniaSmoking()
        test_omega = pd.read_csv("test_data/omega_CalifolinaSmoking.csv")
        test_lambda = pd.read_csv("test_data/lambda_CalifolinaSmoking.csv")
        PRE_TERM = [1970, 1988]
        POST_TERM = [1989, 2000]

        TREATMENT = ["California"]

        sdid = SynthDID(test_df, PRE_TERM, POST_TERM, TREATMENT)

        sdid.fit(zeta_type="base")

        hat_omega_sdid, hat_lambda_sdid = sdid.estimated_params()
        hat_omega = sdid.estimated_params(model="sc")

        omega_result = pd.merge(
            test_omega, hat_omega_sdid, left_on="state", right_on="features", how="left"
        )
        omega_result = pd.merge(
            omega_result, hat_omega, left_on="state", right_on="features", how="left"
        )
        omega_result["random"] = 1 / len(omega_result)

        # Baseline: total absolute deviation of uniform ("random") weights
        # from the paper's weights.
        error_random_omega = np.sqrt(
            omega_result.eval("omega_sdid - random") ** 2
        ).sum()

        # Classic SC result
        error_sc_omega = np.sqrt(omega_result.eval("omega_ADH - sc_weight") ** 2).sum()
        assert error_sc_omega < error_random_omega

        adh_corr, _p = spearmanr(omega_result["omega_ADH"], omega_result["sc_weight"])
        assert adh_corr >= 0.9

        # SDID result
        error_sdid_omega = np.sqrt(
            omega_result.eval("omega_sdid - sdid_weight") ** 2
        ).sum()
        assert error_sdid_omega < error_random_omega

        sdid_corr, _p = spearmanr(
            omega_result["omega_sdid"], omega_result["sdid_weight"]
        )
        assert sdid_corr >= 0.9

        lambda_result = pd.merge(
            test_lambda, hat_lambda_sdid, left_on="year", right_on="time", how="left"
        )
        lambda_result["random"] = 1 / len(lambda_result)

        # lambda test
        error_random_lambda = np.sqrt(
            lambda_result.eval("lambda_sdid - random") ** 2
        ).sum()
        error_sdid_lambda = np.sqrt(
            lambda_result.eval("lambda_sdid - sdid_weight") ** 2
        ).sum()
        assert error_sdid_lambda < error_random_lambda

        sdid_corr, _p = spearmanr(
            lambda_result["lambda_sdid"], lambda_result["sdid_weight"]
        )
        assert sdid_corr >= 0.9
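    # The assertions in test_params_with_originalpaper are deliberately loose:
    # estimated weights only need to beat uniform weights on total absolute
    # error and track the paper's published weights with a Spearman rank
    # correlation of at least 0.9, rather than match them exactly.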
    def test_multi_treatment(self):
        """
        If the number of treated units changes, zeta should change as well.
        """

        test_df = fetch_CaliforniaSmoking()
        PRE_TERM = [1970, 1979]
        POST_TERM = [1980, 1988]

        treatment = [col for i, col in enumerate(test_df.columns) if i % 2 == 0]

        multi_sdid = SynthDID(test_df, PRE_TERM, POST_TERM, treatment)
        multi_sdid.fit(zeta_type="base")

        hat_omega_sdid, hat_lambda_sdid = multi_sdid.estimated_params()
        hat_omega = multi_sdid.estimated_params(model="sc")

        assert (
            np.round(
                hat_omega_sdid.query("features != 'intercept'")["sdid_weight"].sum(), 2
            )
            == 1.0
        )

        assert np.round(hat_omega["sc_weight"].sum(), 2) == 1.0

        treatment2 = [col for i, col in enumerate(test_df.columns) if i % 3 == 0]

        multi_sdid2 = SynthDID(test_df, PRE_TERM, POST_TERM, treatment2)
        multi_sdid2.fit(zeta_type="base")

        assert multi_sdid2.zeta != multi_sdid.zeta

    def test_short_preterm(self):
        """
        Try an extreme case: a pre-treatment period of only two years.
        """

        test_df = fetch_CaliforniaSmoking()
        pre_term = [1970, 1971]
        post_term = [1972, 2000]

        treatment = ["California"]

        short_sdid = SynthDID(test_df, pre_term, post_term, treatment)
        short_sdid.fit(zeta_type="base")

        hat_omega_sdid, hat_lambda_sdid = short_sdid.estimated_params()
        hat_omega = short_sdid.estimated_params(model="sc")

        assert (
            np.round(
                hat_omega_sdid.query("features != 'intercept'")["sdid_weight"].sum(), 2
            )
            == 1.0
        )

        assert np.round(hat_omega["sc_weight"].sum(), 2) == 1.0
--------------------------------------------------------------------------------